diff --git a/.circleci/config.yml b/.circleci/config.yml
index 32b4077e93d..a2e06e55399 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -190,10 +190,10 @@ jobs:
       - image: circleci/python:3.7
     steps:
       - checkout
-      - run:
-          command: |
-            pip install --user --progress-bar off flake8 typing
-            flake8 --config=setup.cfg .
+      - run: pip install --user --progress-bar=off isort==5.* black==21.6b0 flake8
+      - run: isort --settings-path=setup.cfg --check-only .
+      - run: black --config=black.toml --check .
+      - run: flake8 --config=setup.cfg .

   python_type_check:
     docker:
diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in
index 15cb7eb6a07..79c97b49f49 100644
--- a/.circleci/config.yml.in
+++ b/.circleci/config.yml.in
@@ -190,10 +190,10 @@ jobs:
       - image: circleci/python:3.7
     steps:
       - checkout
-      - run:
-          command: |
-            pip install --user --progress-bar off flake8 typing
-            flake8 --config=setup.cfg .
+      - run: pip install --user --progress-bar=off isort==5.* black==21.6b0 flake8
+      - run: isort --settings-path=setup.cfg --check-only .
+      - run: black --config=black.toml --check .
+      - run: flake8 --config=setup.cfg .

   python_type_check:
     docker:
diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py
index ce7cf4cedbb..4be5659a195 100755
--- a/.circleci/regenerate.py
+++ b/.circleci/regenerate.py
@@ -14,68 +14,77 @@
 https://github.com/pytorch/vision/pull/1321#issuecomment-531033978
 """

-import jinja2
-from jinja2 import select_autoescape
-import yaml
 import os.path

+import jinja2
+import yaml
+from jinja2 import select_autoescape

 PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]

 RC_PATTERN = r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"


-def build_workflows(prefix='', filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
+def build_workflows(prefix="", filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
     w = []
     for btype in ["wheel", "conda"]:
         for os_type in ["linux", "macos", "win"]:
             python_versions = PYTHON_VERSIONS
-            cu_versions_dict = {"linux": ["cpu", "cu102", "cu111", "rocm4.1", "rocm4.2"],
-                                "win": ["cpu", "cu102", "cu111"],
-                                "macos": ["cpu"]}
+            cu_versions_dict = {
+                "linux": ["cpu", "cu102", "cu111", "rocm4.1", "rocm4.2"],
+                "win": ["cpu", "cu102", "cu111"],
+                "macos": ["cpu"],
+            }
             cu_versions = cu_versions_dict[os_type]
             for python_version in python_versions:
                 for cu_version in cu_versions:
                     # ROCm conda packages not yet supported
-                    if cu_version.startswith('rocm') and btype == "conda":
+                    if cu_version.startswith("rocm") and btype == "conda":
                         continue
                     for unicode in [False]:
                         fb = filter_branch
-                        if windows_latest_only and os_type == "win" and filter_branch is None and \
-                            (python_version != python_versions[-1] or
-                             (cu_version not in [cu_versions[0], cu_versions[-1]])):
+                        if (
+                            windows_latest_only
+                            and os_type == "win"
+                            and filter_branch is None
+                            and (
+                                python_version != python_versions[-1]
+                                or (cu_version not in [cu_versions[0], cu_versions[-1]])
+                            )
+                        ):
                             fb = "master"
-                        if not fb and (os_type == 'linux' and
-                                       cu_version == 'cpu' and
-                                       btype == 'wheel' and
-                                       python_version == '3.7'):
+                        if not fb and (
+                            os_type == "linux" and cu_version == "cpu" and btype == "wheel" and python_version == "3.7"
+                        ):
                             # the fields must match the build_docs "requires" dependency
                             fb = "/.*/"
                         w += workflow_pair(
-                            btype, os_type, python_version, cu_version,
-                            unicode, prefix, upload, filter_branch=fb)
+                            btype, os_type, python_version, cu_version, unicode, prefix, upload, filter_branch=fb
+                        )

     if not filter_branch:
         # Build on every pull request, but upload only on
nightly and tags - w += build_doc_job('/.*/') - w += upload_doc_job('nightly') + w += build_doc_job("/.*/") + w += upload_doc_job("nightly") return indent(indentation, w) -def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix='', upload=False, *, filter_branch=None): +def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix="", upload=False, *, filter_branch=None): w = [] unicode_suffix = "u" if unicode else "" base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}{unicode_suffix}_{cu_version}" - w.append(generate_base_workflow( - base_workflow_name, python_version, cu_version, - unicode, os_type, btype, filter_branch=filter_branch)) + w.append( + generate_base_workflow( + base_workflow_name, python_version, cu_version, unicode, os_type, btype, filter_branch=filter_branch + ) + ) if upload: w.append(generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, filter_branch=filter_branch)) - if filter_branch == 'nightly' and os_type in ['linux', 'win']: - pydistro = 'pip' if btype == 'wheel' else 'conda' + if filter_branch == "nightly" and os_type in ["linux", "win"]: + pydistro = "pip" if btype == "wheel" else "conda" w.append(generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, python_version, os_type)) return w @@ -85,12 +94,13 @@ def build_doc_job(filter_branch): job = { "name": "build_docs", "python_version": "3.7", - "requires": ["binary_linux_wheel_py3.7_cpu", ], + "requires": [ + "binary_linux_wheel_py3.7_cpu", + ], } if filter_branch: - job["filters"] = gen_filter_branch_tree(filter_branch, - tags_list=RC_PATTERN) + job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN) return [{"build_docs": job}] @@ -99,12 +109,13 @@ def upload_doc_job(filter_branch): "name": "upload_docs", "context": "org-member", "python_version": "3.7", - "requires": ["build_docs", ], + "requires": [ + "build_docs", + ], } if filter_branch: - job["filters"] = gen_filter_branch_tree(filter_branch, - tags_list=RC_PATTERN) + job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN) return [{"upload_docs": job}] @@ -121,24 +132,25 @@ def upload_doc_job(filter_branch): def get_manylinux_image(cu_version): if cu_version == "cpu": return "pytorch/manylinux-cuda102" - elif cu_version.startswith('cu'): - cu_suffix = cu_version[len('cu'):] + elif cu_version.startswith("cu"): + cu_suffix = cu_version[len("cu") :] return f"pytorch/manylinux-cuda{cu_suffix}" - elif cu_version.startswith('rocm'): - rocm_suffix = cu_version[len('rocm'):] + elif cu_version.startswith("rocm"): + rocm_suffix = cu_version[len("rocm") :] return f"pytorch/manylinux-rocm:{rocm_suffix}" def get_conda_image(cu_version): if cu_version == "cpu": return "pytorch/conda-builder:cpu" - elif cu_version.startswith('cu'): - cu_suffix = cu_version[len('cu'):] + elif cu_version.startswith("cu"): + cu_suffix = cu_version[len("cu") :] return f"pytorch/conda-builder:cuda{cu_suffix}" -def generate_base_workflow(base_workflow_name, python_version, cu_version, - unicode, os_type, btype, *, filter_branch=None): +def generate_base_workflow( + base_workflow_name, python_version, cu_version, unicode, os_type, btype, *, filter_branch=None +): d = { "name": base_workflow_name, @@ -147,7 +159,7 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, } if os_type != "win" and unicode: - d["unicode_abi"] = '1' + d["unicode_abi"] = "1" if os_type != "win": d["wheel_docker_image"] = get_manylinux_image(cu_version) @@ 
-157,14 +169,12 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, if filter_branch is not None: d["filters"] = { - "branches": { - "only": filter_branch - }, + "branches": {"only": filter_branch}, "tags": { # Using a raw string here to avoid having to escape # anything "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/" - } + }, } w = f"binary_{os_type}_{btype}" @@ -185,19 +195,17 @@ def generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, *, "requires": [base_workflow_name], } - if btype == 'wheel': - d["subfolder"] = "" if os_type == 'macos' else cu_version + "/" + if btype == "wheel": + d["subfolder"] = "" if os_type == "macos" else cu_version + "/" if filter_branch is not None: d["filters"] = { - "branches": { - "only": filter_branch - }, + "branches": {"only": filter_branch}, "tags": { # Using a raw string here to avoid having to escape # anything "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/" - } + }, } return {f"binary_{btype}_upload": d} @@ -222,8 +230,7 @@ def generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, pyt def indent(indentation, data_list): - return ("\n" + " " * indentation).join( - yaml.dump(data_list, default_flow_style=False).splitlines()) + return ("\n" + " " * indentation).join(yaml.dump(data_list, default_flow_style=False).splitlines()) def unittest_workflows(indentation=6): @@ -238,12 +245,12 @@ def unittest_workflows(indentation=6): "python_version": python_version, } - if device_type == 'gpu': + if device_type == "gpu": if python_version != "3.8": - job['filters'] = gen_filter_branch_tree('master', 'nightly') - job['cu_version'] = 'cu102' + job["filters"] = gen_filter_branch_tree("master", "nightly") + job["cu_version"] = "cu102" else: - job['cu_version'] = 'cpu' + job["cu_version"] = "cpu" jobs.append({f"unittest_{os_type}_{device_type}": job}) @@ -252,20 +259,17 @@ def unittest_workflows(indentation=6): def cmake_workflows(indentation=6): jobs = [] - python_version = '3.8' - for os_type in ['linux', 'windows', 'macos']: + python_version = "3.8" + for os_type in ["linux", "windows", "macos"]: # Skip OSX CUDA - device_types = ['cpu', 'gpu'] if os_type != 'macos' else ['cpu'] + device_types = ["cpu", "gpu"] if os_type != "macos" else ["cpu"] for device in device_types: - job = { - 'name': f'cmake_{os_type}_{device}', - 'python_version': python_version - } + job = {"name": f"cmake_{os_type}_{device}", "python_version": python_version} - job['cu_version'] = 'cu102' if device == 'gpu' else 'cpu' - if device == 'gpu' and os_type == 'linux': - job['wheel_docker_image'] = 'pytorch/manylinux-cuda102' - jobs.append({f'cmake_{os_type}_{device}': job}) + job["cu_version"] = "cu102" if device == "gpu" else "cpu" + if device == "gpu" and os_type == "linux": + job["wheel_docker_image"] = "pytorch/manylinux-cuda102" + jobs.append({f"cmake_{os_type}_{device}": job}) return indent(indentation, jobs) @@ -274,27 +278,27 @@ def ios_workflows(indentation=6, nightly=False): build_job_names = [] name_prefix = "nightly_" if nightly else "" env_prefix = "nightly-" if nightly else "" - for arch, platform in [('x86_64', 'SIMULATOR'), ('arm64', 'OS')]: - name = f'{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}' + for arch, platform in [("x86_64", "SIMULATOR"), ("arm64", "OS")]: + name = f"{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}" build_job_names.append(name) build_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}', - 'ios_arch': arch, - 'ios_platform': platform, - 'name': 
name, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}", + "ios_arch": arch, + "ios_platform": platform, + "name": name, } if nightly: - build_job['filters'] = gen_filter_branch_tree('nightly') - jobs.append({'binary_ios_build': build_job}) + build_job["filters"] = gen_filter_branch_tree("nightly") + jobs.append({"binary_ios_build": build_job}) if nightly: upload_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload', - 'context': 'org-member', - 'filters': gen_filter_branch_tree('nightly'), - 'requires': build_job_names, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload", + "context": "org-member", + "filters": gen_filter_branch_tree("nightly"), + "requires": build_job_names, } - jobs.append({'binary_ios_upload': upload_job}) + jobs.append({"binary_ios_upload": upload_job}) return indent(indentation, jobs) @@ -304,23 +308,23 @@ def android_workflows(indentation=6, nightly=False): name_prefix = "nightly_" if nightly else "" env_prefix = "nightly-" if nightly else "" - name = f'{name_prefix}binary_libtorchvision_ops_android' + name = f"{name_prefix}binary_libtorchvision_ops_android" build_job_names.append(name) build_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-android', - 'name': name, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-android", + "name": name, } if nightly: upload_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-android-upload', - 'context': 'org-member', - 'filters': gen_filter_branch_tree('nightly'), - 'name': f'{name_prefix}binary_libtorchvision_ops_android_upload' + "build_environment": f"{env_prefix}binary-libtorchvision_ops-android-upload", + "context": "org-member", + "filters": gen_filter_branch_tree("nightly"), + "name": f"{name_prefix}binary_libtorchvision_ops_android_upload", } - jobs.append({'binary_android_upload': upload_job}) + jobs.append({"binary_android_upload": upload_job}) else: - jobs.append({'binary_android_build': build_job}) + jobs.append({"binary_android_build": build_job}) return indent(indentation, jobs) @@ -329,15 +333,17 @@ def android_workflows(indentation=6, nightly=False): env = jinja2.Environment( loader=jinja2.FileSystemLoader(d), lstrip_blocks=True, - autoescape=select_autoescape(enabled_extensions=('html', 'xml')), + autoescape=select_autoescape(enabled_extensions=("html", "xml")), keep_trailing_newline=True, ) - with open(os.path.join(d, 'config.yml'), 'w') as f: - f.write(env.get_template('config.yml.in').render( - build_workflows=build_workflows, - unittest_workflows=unittest_workflows, - cmake_workflows=cmake_workflows, - ios_workflows=ios_workflows, - android_workflows=android_workflows, - )) + with open(os.path.join(d, "config.yml"), "w") as f: + f.write( + env.get_template("config.yml.in").render( + build_workflows=build_workflows, + unittest_workflows=unittest_workflows, + cmake_workflows=cmake_workflows, + ios_workflows=ios_workflows, + android_workflows=android_workflows, + ) + ) diff --git a/.circleci/unittest/linux/scripts/run-clang-format.py b/.circleci/unittest/linux/scripts/run-clang-format.py index 7bbd1acd0f4..bf313ec8253 100755 --- a/.circleci/unittest/linux/scripts/run-clang-format.py +++ b/.circleci/unittest/linux/scripts/run-clang-format.py @@ -32,7 +32,6 @@ """ import argparse -import codecs import difflib import fnmatch import io @@ -42,7 +41,6 @@ import subprocess import sys import traceback - from functools import partial try: @@ -51,7 +49,7 @@ DEVNULL = 
open(os.devnull, "wb") -DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu' +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" class ExitStatus: @@ -75,14 +73,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None): # os.walk() supports trimming down the dnames list # by modifying it in-place, # to avoid unnecessary directory listings. - dnames[:] = [ - x for x in dnames - if - not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) - ] - fpaths = [ - x for x in fpaths if not fnmatch.fnmatch(x, pattern) - ] + dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] for f in fpaths: ext = os.path.splitext(f)[1][1:] if ext in extensions: @@ -95,11 +87,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None): def make_diff(file, original, reformatted): return list( difflib.unified_diff( - original, - reformatted, - fromfile='{}\t(original)'.format(file), - tofile='{}\t(reformatted)'.format(file), - n=3)) + original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3 + ) + ) class DiffError(Exception): @@ -122,13 +112,12 @@ def run_clang_format_diff_wrapper(args, file): except DiffError: raise except Exception as e: - raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__, - e), e) + raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e) def run_clang_format_diff(args, file): try: - with io.open(file, 'r', encoding='utf-8') as f: + with io.open(file, "r", encoding="utf-8") as f: original = f.readlines() except IOError as exc: raise DiffError(str(exc)) @@ -153,17 +142,10 @@ def run_clang_format_diff(args, file): try: proc = subprocess.Popen( - invocation, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - encoding='utf-8') - except OSError as exc: - raise DiffError( - "Command '{}' failed to start: {}".format( - subprocess.list2cmdline(invocation), exc - ) + invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8" ) + except OSError as exc: + raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc)) proc_stdout = proc.stdout proc_stderr = proc.stderr @@ -182,30 +164,30 @@ def run_clang_format_diff(args, file): def bold_red(s): - return '\x1b[1m\x1b[31m' + s + '\x1b[0m' + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" def colorize(diff_lines): def bold(s): - return '\x1b[1m' + s + '\x1b[0m' + return "\x1b[1m" + s + "\x1b[0m" def cyan(s): - return '\x1b[36m' + s + '\x1b[0m' + return "\x1b[36m" + s + "\x1b[0m" def green(s): - return '\x1b[32m' + s + '\x1b[0m' + return "\x1b[32m" + s + "\x1b[0m" def red(s): - return '\x1b[31m' + s + '\x1b[0m' + return "\x1b[31m" + s + "\x1b[0m" for line in diff_lines: - if line[:4] in ['--- ', '+++ ']: + if line[:4] in ["--- ", "+++ "]: yield bold(line) - elif line.startswith('@@ '): + elif line.startswith("@@ "): yield cyan(line) - elif line.startswith('+'): + elif line.startswith("+"): yield green(line) - elif line.startswith('-'): + elif line.startswith("-"): yield red(line) else: yield line @@ -218,7 +200,7 @@ def print_diff(diff_lines, use_color): def print_trouble(prog, message, use_colors): - error_text = 'error:' + error_text = "error:" if use_colors: error_text = bold_red(error_text) print("{}: {} {}".format(prog, error_text, message), file=sys.stderr) @@ -227,45 +209,37 @@ def print_trouble(prog, message, 
use_colors): def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--clang-format-executable', - metavar='EXECUTABLE', - help='path to the clang-format executable', - default='clang-format') - parser.add_argument( - '--extensions', - help='comma separated list of file extensions (default: {})'.format( - DEFAULT_EXTENSIONS), - default=DEFAULT_EXTENSIONS) + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) parser.add_argument( - '-r', - '--recursive', - action='store_true', - help='run recursively over directories') - parser.add_argument('files', metavar='file', nargs='+') + "--extensions", + help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS), + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories") + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") parser.add_argument( - '-q', - '--quiet', - action='store_true') - parser.add_argument( - '-j', - metavar='N', + "-j", + metavar="N", type=int, default=0, - help='run N clang-format jobs in parallel' - ' (default number of cpus + 1)') + help="run N clang-format jobs in parallel" " (default number of cpus + 1)", + ) parser.add_argument( - '--color', - default='auto', - choices=['auto', 'always', 'never'], - help='show colored diff (default: auto)') + "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)" + ) parser.add_argument( - '-e', - '--exclude', - metavar='PATTERN', - action='append', + "-e", + "--exclude", + metavar="PATTERN", + action="append", default=[], - help='exclude paths matching the given glob-like pattern(s)' - ' from recursive search') + help="exclude paths matching the given glob-like pattern(s)" " from recursive search", + ) args = parser.parse_args() @@ -282,10 +256,10 @@ def main(): colored_stdout = False colored_stderr = False - if args.color == 'always': + if args.color == "always": colored_stdout = True colored_stderr = True - elif args.color == 'auto': + elif args.color == "auto": colored_stdout = sys.stdout.isatty() colored_stderr = sys.stderr.isatty() @@ -298,19 +272,15 @@ def main(): except OSError as e: print_trouble( parser.prog, - "Command '{}' failed to start: {}".format( - subprocess.list2cmdline(version_invocation), e - ), + "Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e), use_colors=colored_stderr, ) return ExitStatus.TROUBLE retcode = ExitStatus.SUCCESS files = list_files( - args.files, - recursive=args.recursive, - exclude=args.exclude, - extensions=args.extensions.split(',')) + args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",") + ) if not files: return @@ -327,8 +297,7 @@ def main(): pool = None else: pool = multiprocessing.Pool(njobs) - it = pool.imap_unordered( - partial(run_clang_format_diff_wrapper, args), files) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) while True: try: outs, errs = next(it) @@ -359,5 +328,5 @@ def main(): return retcode -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml index 65f805ce471..68903afed60 100644 --- a/.github/workflows/tests-schedule.yml +++ b/.github/workflows/tests-schedule.yml @@ -30,7 +30,7 @@ 
jobs:
         run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

       - name: Install torchvision
-        run: pip install -e .
+        run: pip install --no-build-isolation --editable .

       - name: Install all optional dataset requirements
         run: pip install scipy pandas pycocotools lmdb requests
diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py
index 7860c759a57..fedee39fc52 100644
--- a/android/test_app/make_assets.py
+++ b/android/test_app/make_assets.py
@@ -5,11 +5,8 @@
 print(torch.__version__)

 model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True,
-    box_score_thresh=0.7,
-    rpn_post_nms_top_n_test=100,
-    rpn_score_thresh=0.4,
-    rpn_pre_nms_top_n_test=150)
+    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+)
 model.eval()

 script_model = torch.jit.script(model)
diff --git a/black.toml b/black.toml
new file mode 100644
index 00000000000..eb96429c25b
--- /dev/null
+++ b/black.toml
@@ -0,0 +1,12 @@
+# We are diverging from the standard pyproject.toml here since simply having a pyproject.toml file in the root
+# folder changes the default behavior of pip. To use this configuration, run black with --config=black.toml.
+
+[tool.black]
+# See link below for available options
+# https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file
+
+line-length = 120
+target-version = ["py36"]
+
+exclude = "gallery"
+extend_exclude = true
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6bbb05c13c7..485ae297e98 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,9 +21,12 @@
 # import sys
 # sys.path.insert(0, os.path.abspath('.'))

-import torchvision
 import pytorch_sphinx_theme
+from docutils import nodes
+from sphinx import addnodes
+from sphinx.util.docfields import TypedField

+import torchvision

 # -- General configuration ------------------------------------------------
@@ -33,24 +36,24 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.doctest',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.todo',
-    'sphinx.ext.mathjax',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.duration',
-    'sphinx_gallery.gen_gallery',
-    'sphinx_copybutton',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.doctest",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.duration",
+    "sphinx_gallery.gen_gallery",
+    "sphinx_copybutton",
 ]

 sphinx_gallery_conf = {
-    'examples_dirs': '../../gallery/',  # path to your example scripts
-    'gallery_dirs': 'auto_examples',  # path to where to save gallery generated output
-    'backreferences_dir': 'gen_modules/backreferences',
-    'doc_module': ('torchvision',),
+    "examples_dirs": "../../gallery/",  # path to your example scripts
+    "gallery_dirs": "auto_examples",  # path to where to save gallery generated output
+    "backreferences_dir": "gen_modules/backreferences",
+    "doc_module": ("torchvision",),
 }

 napoleon_use_ivar = True
@@ -59,22 +62,22 @@
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # source_suffix = { - '.rst': 'restructuredtext', + ".rst": "restructuredtext", } # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'Torchvision' -copyright = '2017-present, Torch Contributors' -author = 'Torch Contributors' +project = "Torchvision" +copyright = "2017-present, Torch Contributors" +author = "Torch Contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -82,10 +85,10 @@ # # The short X.Y version. # TODO: change to [:2] at v1.0 -version = 'master (' + torchvision.__version__ + ' )' +version = "master (" + torchvision.__version__ + " )" # The full version, including alpha/beta/rc tags. # TODO: verify this works as expected -release = 'master' +release = "master" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -100,7 +103,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -111,7 +114,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme @@ -119,30 +122,30 @@ # documentation. # html_theme_options = { - 'collapse_navigation': False, - 'display_version': True, - 'logo_only': True, - 'pytorch_project': 'docs', - 'navigation_with_keys': True, - 'analytics_id': 'UA-117752657-2', + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "pytorch_project": "docs", + "navigation_with_keys": True, + "analytics_id": "UA-117752657-2", } -html_logo = '_static/img/pytorch-logo-dark.svg' +html_logo = "_static/img/pytorch-logo-dark.svg" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed html_css_files = [ - 'css/custom_torchvision.css', + "css/custom_torchvision.css", ] # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'PyTorchdoc' +htmlhelp_basename = "PyTorchdoc" # -- Options for LaTeX output --------------------------------------------- @@ -150,15 +153,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -169,8 +169,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). 
latex_documents = [ - (master_doc, 'pytorch.tex', 'torchvision Documentation', - 'Torch Contributors', 'manual'), + (master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"), ] @@ -178,10 +177,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'torchvision', 'torchvision Documentation', - [author], 1) -] +man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -190,28 +186,30 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'torchvision', 'torchvision Documentation', - author, 'torchvision', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "torchvision", + "torchvision Documentation", + author, + "torchvision", + "One line description of project.", + "Miscellaneous", + ), ] # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), - 'matplotlib': ('https://matplotlib.org/stable/', None), + "python": ("https://docs.python.org/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "PIL": ("https://pillow.readthedocs.io/en/stable/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), } # -- A patch that prevents Sphinx from cross-referencing ivar tags ------- # See http://stackoverflow.com/a/41184353/3343043 -from docutils import nodes -from sphinx.util.docfields import TypedField -from sphinx import addnodes - def patched_make_field(self, types, domain, items, **kw): # `kw` catches `env=None` needed for newer sphinx while maintaining @@ -220,40 +218,39 @@ def patched_make_field(self, types, domain, items, **kw): # type: (list, unicode, tuple) -> nodes.field # noqa: F821 def handle_item(fieldarg, content): par = nodes.paragraph() - par += addnodes.literal_strong('', fieldarg) # Patch: this line added + par += addnodes.literal_strong("", fieldarg) # Patch: this line added # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, # addnodes.literal_strong)) if fieldarg in types: - par += nodes.Text(' (') + par += nodes.Text(" (") # NOTE: using .pop() here to prevent a single type node to be # inserted twice into the doctree, which leads to # inconsistencies later when references are resolved fieldtype = types.pop(fieldarg) if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = u''.join(n.astext() for n in fieldtype) - typename = typename.replace('int', 'python:int') - typename = typename.replace('long', 'python:long') - typename = typename.replace('float', 'python:float') - typename = typename.replace('type', 'python:type') - par.extend(self.make_xrefs(self.typerolename, domain, typename, - addnodes.literal_emphasis, **kw)) + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("type", "python:type") + par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw)) else: par += 
fieldtype - par += nodes.Text(')') - par += nodes.Text(' -- ') + par += nodes.Text(")") + par += nodes.Text(" -- ") par += content return par - fieldname = nodes.field_name('', self.label) + fieldname = nodes.field_name("", self.label) if len(items) == 1 and self.can_collapse: fieldarg, content = items[0] bodynode = handle_item(fieldarg, content) else: bodynode = self.list_type() for fieldarg, content in items: - bodynode += nodes.list_item('', handle_item(fieldarg, content)) - fieldbody = nodes.field_body('', bodynode) - return nodes.field('', fieldname, fieldbody) + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) TypedField.make_field = patched_make_field @@ -286,4 +283,4 @@ def inject_minigalleries(app, what, name, obj, options, lines): def setup(app): - app.connect('autodoc-process-docstring', inject_minigalleries) + app.connect("autodoc-process-docstring", inject_minigalleries) diff --git a/hubconf.py b/hubconf.py index 097759bdd89..4a4affb1677 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,21 +1,35 @@ # Optional list of dependencies required by the package -dependencies = ['torch'] - # classification from torchvision.models.alexnet import alexnet -from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 -from torchvision.models.inception import inception_v3 -from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\ - resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 -from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 -from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn +from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201 from torchvision.models.googlenet import googlenet -from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 +from torchvision.models.inception import inception_v3 +from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3 from torchvision.models.mobilenetv2 import mobilenet_v2 from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small -from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ - mnasnet1_3 +from torchvision.models.resnet import ( + resnet18, + resnet34, + resnet50, + resnet101, + resnet152, + resnext50_32x4d, + resnext101_32x8d, + wide_resnet50_2, + wide_resnet101_2, +) # segmentation -from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \ - deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large +from torchvision.models.segmentation import ( + deeplabv3_mobilenet_v3_large, + deeplabv3_resnet50, + deeplabv3_resnet101, + fcn_resnet50, + fcn_resnet101, + lraspp_mobilenet_v3_large, +) +from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 +from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 +from torchvision.models.vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn + +dependencies = ["torch"] diff --git a/ios/VisionTestApp/make_assets.py b/ios/VisionTestApp/make_assets.py index 122094b3547..0f46364569b 100644 --- a/ios/VisionTestApp/make_assets.py +++ b/ios/VisionTestApp/make_assets.py @@ -5,11 +5,8 @@ print(torch.__version__) model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, - 
box_score_thresh=0.7, - rpn_post_nms_top_n_test=100, - rpn_score_thresh=0.4, - rpn_pre_nms_top_n_test=150) + pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +) model.eval() script_model = torch.jit.script(model) diff --git a/packaging/wheel/relocate.py b/packaging/wheel/relocate.py index dd2c5d2a4ce..3a94d3a58c1 100644 --- a/packaging/wheel/relocate.py +++ b/packaging/wheel/relocate.py @@ -2,46 +2,62 @@ """Helper script to package wheels and relocate binaries.""" -# Standard library imports -import os -import io -import sys import glob -import shutil -import zipfile import hashlib +import io + +# Standard library imports +import os +import os.path as osp import platform +import shutil import subprocess -import os.path as osp +import sys +import zipfile from base64 import urlsafe_b64encode # Third party imports -if sys.platform == 'linux': +if sys.platform == "linux": from auditwheel.lddtree import lddtree -from wheel.bdist_wheel import get_abi_tag ALLOWLIST = { - 'libgcc_s.so.1', 'libstdc++.so.6', 'libm.so.6', - 'libdl.so.2', 'librt.so.1', 'libc.so.6', - 'libnsl.so.1', 'libutil.so.1', 'libpthread.so.0', - 'libresolv.so.2', 'libX11.so.6', 'libXext.so.6', - 'libXrender.so.1', 'libICE.so.6', 'libSM.so.6', - 'libGL.so.1', 'libgobject-2.0.so.0', 'libgthread-2.0.so.0', - 'libglib-2.0.so.0', 'ld-linux-x86-64.so.2', 'ld-2.17.so' + "libgcc_s.so.1", + "libstdc++.so.6", + "libm.so.6", + "libdl.so.2", + "librt.so.1", + "libc.so.6", + "libnsl.so.1", + "libutil.so.1", + "libpthread.so.0", + "libresolv.so.2", + "libX11.so.6", + "libXext.so.6", + "libXrender.so.1", + "libICE.so.6", + "libSM.so.6", + "libGL.so.1", + "libgobject-2.0.so.0", + "libgthread-2.0.so.0", + "libglib-2.0.so.0", + "ld-linux-x86-64.so.2", + "ld-2.17.so", } WINDOWS_ALLOWLIST = { - 'MSVCP140.dll', 'KERNEL32.dll', - 'VCRUNTIME140_1.dll', 'VCRUNTIME140.dll', - 'api-ms-win-crt-heap-l1-1-0.dll', - 'api-ms-win-crt-runtime-l1-1-0.dll', - 'api-ms-win-crt-stdio-l1-1-0.dll', - 'api-ms-win-crt-filesystem-l1-1-0.dll', - 'api-ms-win-crt-string-l1-1-0.dll', - 'api-ms-win-crt-environment-l1-1-0.dll', - 'api-ms-win-crt-math-l1-1-0.dll', - 'api-ms-win-crt-convert-l1-1-0.dll' + "MSVCP140.dll", + "KERNEL32.dll", + "VCRUNTIME140_1.dll", + "VCRUNTIME140.dll", + "api-ms-win-crt-heap-l1-1-0.dll", + "api-ms-win-crt-runtime-l1-1-0.dll", + "api-ms-win-crt-stdio-l1-1-0.dll", + "api-ms-win-crt-filesystem-l1-1-0.dll", + "api-ms-win-crt-string-l1-1-0.dll", + "api-ms-win-crt-environment-l1-1-0.dll", + "api-ms-win-crt-math-l1-1-0.dll", + "api-ms-win-crt-convert-l1-1-0.dll", } @@ -64,20 +80,18 @@ def rehash(path, blocksize=1 << 20): """Return (hash, length) for path using hashlib.sha256()""" h = hashlib.sha256() length = 0 - with open(path, 'rb') as f: + with open(path, "rb") as f: for block in read_chunks(f, size=blocksize): length += len(block) h.update(block) - digest = 'sha256=' + urlsafe_b64encode( - h.digest() - ).decode('latin1').rstrip('=') + digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("latin1").rstrip("=") # unicode/str python2 issues return (digest, str(length)) # type: ignore def unzip_file(file, dest): """Decompress zip `file` into directory `dest`.""" - with zipfile.ZipFile(file, 'r') as zip_ref: + with zipfile.ZipFile(file, "r") as zip_ref: zip_ref.extractall(dest) @@ -88,8 +102,7 @@ def is_program_installed(basename): On macOS systems, a .app is considered installed if it exists. 
""" - if (sys.platform == 'darwin' and basename.endswith('.app') and - osp.exists(basename)): + if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename): return basename for path in os.environ["PATH"].split(os.pathsep): @@ -105,9 +118,9 @@ def find_program(basename): (return None if not found) """ names = [basename] - if os.name == 'nt': + if os.name == "nt": # Windows platforms - extensions = ('.exe', '.bat', '.cmd', '.dll') + extensions = (".exe", ".bat", ".cmd", ".dll") if not basename.endswith(extensions): names = [basename + ext for ext in extensions] + [basename] for name in names: @@ -118,19 +131,18 @@ def find_program(basename): def patch_new_path(library_path, new_dir): library = osp.basename(library_path) - name, *rest = library.split('.') - rest = '.'.join(rest) - hash_id = hashlib.sha256(library_path.encode('utf-8')).hexdigest()[:8] - new_name = '.'.join([name, hash_id, rest]) + name, *rest = library.split(".") + rest = ".".join(rest) + hash_id = hashlib.sha256(library_path.encode("utf-8")).hexdigest()[:8] + new_name = ".".join([name, hash_id, rest]) return osp.join(new_dir, new_name) def find_dll_dependencies(dumpbin, binary): - out = subprocess.run([dumpbin, "/dependents", binary], - stdout=subprocess.PIPE) - out = out.stdout.strip().decode('utf-8') - start_index = out.find('dependencies:') + len('dependencies:') - end_index = out.find('Summary') + out = subprocess.run([dumpbin, "/dependents", binary], stdout=subprocess.PIPE) + out = out.stdout.strip().decode("utf-8") + start_index = out.find("dependencies:") + len("dependencies:") + end_index = out.find("Summary") dlls = out[start_index:end_index].strip() dlls = dlls.split(os.linesep) dlls = [dll.strip() for dll in dlls] @@ -145,13 +157,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): rename and copy them into the wheel while updating their respective rpaths. 
""" - print('Relocating {0}'.format(binary)) + print("Relocating {0}".format(binary)) binary_path = osp.join(output_library, binary) ld_tree = lddtree(binary_path) - tree_libs = ld_tree['libs'] + tree_libs = ld_tree["libs"] - binary_queue = [(n, binary) for n in ld_tree['needed']] + binary_queue = [(n, binary) for n in ld_tree["needed"]] binary_paths = {binary: binary_path} binary_dependencies = {} @@ -160,13 +172,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): library_info = tree_libs[library] print(library) - if library_info['path'] is None: - print('Omitting {0}'.format(library)) + if library_info["path"] is None: + print("Omitting {0}".format(library)) continue if library in ALLOWLIST: # Omit glibc/gcc/system libraries - print('Omitting {0}'.format(library)) + print("Omitting {0}".format(library)) continue parent_dependencies = binary_dependencies.get(parent, []) @@ -176,11 +188,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): if library in binary_paths: continue - binary_paths[library] = library_info['path'] - binary_queue += [(n, library) for n in library_info['needed']] + binary_paths[library] = library_info["path"] + binary_queue += [(n, library) for n in library_info["needed"]] - print('Copying dependencies to wheel directory') - new_libraries_path = osp.join(output_dir, 'torchvision.libs') + print("Copying dependencies to wheel directory") + new_libraries_path = osp.join(output_dir, "torchvision.libs") os.makedirs(new_libraries_path) new_names = {binary: binary_path} @@ -189,11 +201,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): if library != binary: library_path = binary_paths[library] new_library_path = patch_new_path(library_path, new_libraries_path) - print('{0} -> {1}'.format(library, new_library_path)) + print("{0} -> {1}".format(library, new_library_path)) shutil.copyfile(library_path, new_library_path) new_names[library] = new_library_path - print('Updating dependency names by new files') + print("Updating dependency names by new files") for library in binary_paths: if library != binary: if library not in binary_dependencies: @@ -202,59 +214,26 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): new_library_name = new_names[library] for dep in library_dependencies: new_dep = osp.basename(new_names[dep]) - print('{0}: {1} -> {2}'.format(library, dep, new_dep)) + print("{0}: {1} -> {2}".format(library, dep, new_dep)) subprocess.check_output( - [ - patchelf, - '--replace-needed', - dep, - new_dep, - new_library_name - ], - cwd=new_libraries_path) - - print('Updating library rpath') - subprocess.check_output( - [ - patchelf, - '--set-rpath', - "$ORIGIN", - new_library_name - ], - cwd=new_libraries_path) - - subprocess.check_output( - [ - patchelf, - '--print-rpath', - new_library_name - ], - cwd=new_libraries_path) + [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path + ) + + print("Updating library rpath") + subprocess.check_output([patchelf, "--set-rpath", "$ORIGIN", new_library_name], cwd=new_libraries_path) + + subprocess.check_output([patchelf, "--print-rpath", new_library_name], cwd=new_libraries_path) print("Update library dependencies") library_dependencies = binary_dependencies[binary] for dep in library_dependencies: new_dep = osp.basename(new_names[dep]) - print('{0}: {1} -> {2}'.format(binary, dep, new_dep)) - subprocess.check_output( - [ - patchelf, - '--replace-needed', - dep, - new_dep, - binary - ], - 
cwd=output_library) - - print('Update library rpath') + print("{0}: {1} -> {2}".format(binary, dep, new_dep)) + subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library) + + print("Update library rpath") subprocess.check_output( - [ - patchelf, - '--set-rpath', - "$ORIGIN:$ORIGIN/../torchvision.libs", - binary_path - ], - cwd=output_library + [patchelf, "--set-rpath", "$ORIGIN:$ORIGIN/../torchvision.libs", binary_path], cwd=output_library ) @@ -265,7 +244,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): Given a shared library, find the transitive closure of its dependencies, rename and copy them into the wheel. """ - print('Relocating {0}'.format(binary)) + print("Relocating {0}".format(binary)) binary_path = osp.join(output_library, binary) library_dlls = find_dll_dependencies(dumpbin, binary_path) @@ -275,19 +254,19 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): while binary_queue != []: library, parent = binary_queue.pop(0) - if library in WINDOWS_ALLOWLIST or library.startswith('api-ms-win'): - print('Omitting {0}'.format(library)) + if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"): + print("Omitting {0}".format(library)) continue library_path = find_program(library) if library_path is None: - print('{0} not found'.format(library)) + print("{0} not found".format(library)) continue - if osp.basename(osp.dirname(library_path)) == 'system32': + if osp.basename(osp.dirname(library_path)) == "system32": continue - print('{0}: {1}'.format(library, library_path)) + print("{0}: {1}".format(library, library_path)) parent_dependencies = binary_dependencies.get(parent, []) parent_dependencies.append(library) binary_dependencies[parent] = parent_dependencies @@ -299,55 +278,56 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): downstream_dlls = find_dll_dependencies(dumpbin, library_path) binary_queue += [(n, library) for n in downstream_dlls] - print('Copying dependencies to wheel directory') - package_dir = osp.join(output_dir, 'torchvision') + print("Copying dependencies to wheel directory") + package_dir = osp.join(output_dir, "torchvision") for library in binary_paths: if library != binary: library_path = binary_paths[library] new_library_path = osp.join(package_dir, library) - print('{0} -> {1}'.format(library, new_library_path)) + print("{0} -> {1}".format(library, new_library_path)) shutil.copyfile(library_path, new_library_path) def compress_wheel(output_dir, wheel, wheel_dir, wheel_name): """Create RECORD file and compress wheel distribution.""" - print('Update RECORD file in wheel') - dist_info = glob.glob(osp.join(output_dir, '*.dist-info'))[0] - record_file = osp.join(dist_info, 'RECORD') + print("Update RECORD file in wheel") + dist_info = glob.glob(osp.join(output_dir, "*.dist-info"))[0] + record_file = osp.join(dist_info, "RECORD") - with open(record_file, 'w') as f: + with open(record_file, "w") as f: for root, _, files in os.walk(output_dir): for this_file in files: full_file = osp.join(root, this_file) rel_file = osp.relpath(full_file, output_dir) if full_file == record_file: - f.write('{0},,\n'.format(rel_file)) + f.write("{0},,\n".format(rel_file)) else: digest, size = rehash(full_file) - f.write('{0},{1},{2}\n'.format(rel_file, digest, size)) + f.write("{0},{1},{2}\n".format(rel_file, digest, size)) - print('Compressing wheel') + print("Compressing wheel") base_wheel_name = osp.join(wheel_dir, wheel_name) - 
shutil.make_archive(base_wheel_name, 'zip', output_dir) + shutil.make_archive(base_wheel_name, "zip", output_dir) os.remove(wheel) - shutil.move('{0}.zip'.format(base_wheel_name), wheel) + shutil.move("{0}.zip".format(base_wheel_name), wheel) shutil.rmtree(output_dir) def patch_linux(): # Get patchelf location - patchelf = find_program('patchelf') + patchelf = find_program("patchelf") if patchelf is None: - raise FileNotFoundError('Patchelf was not found in the system, please' - ' make sure that is available on the PATH.') + raise FileNotFoundError( + "Patchelf was not found in the system, please" " make sure that is available on the PATH." + ) # Find wheel - print('Finding wheels...') - wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl')) - output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process') + print("Finding wheels...") + wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl")) + output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process") - image_binary = 'image.so' - video_binary = 'video_reader.so' + image_binary = "image.so" + video_binary = "video_reader.so" torchvision_binaries = [image_binary, video_binary] for wheel in wheels: if osp.exists(output_dir): @@ -355,37 +335,37 @@ def patch_linux(): os.makedirs(output_dir) - print('Unzipping wheel...') + print("Unzipping wheel...") wheel_file = osp.basename(wheel) wheel_dir = osp.dirname(wheel) - print('{0}'.format(wheel_file)) + print("{0}".format(wheel_file)) wheel_name, _ = osp.splitext(wheel_file) unzip_file(wheel, output_dir) - print('Finding ELF dependencies...') - output_library = osp.join(output_dir, 'torchvision') + print("Finding ELF dependencies...") + output_library = osp.join(output_dir, "torchvision") for binary in torchvision_binaries: if osp.exists(osp.join(output_library, binary)): - relocate_elf_library( - patchelf, output_dir, output_library, binary) + relocate_elf_library(patchelf, output_dir, output_library, binary) compress_wheel(output_dir, wheel, wheel_dir, wheel_name) def patch_win(): # Get dumpbin location - dumpbin = find_program('dumpbin') + dumpbin = find_program("dumpbin") if dumpbin is None: - raise FileNotFoundError('Dumpbin was not found in the system, please' - ' make sure that is available on the PATH.') + raise FileNotFoundError( + "Dumpbin was not found in the system, please" " make sure that is available on the PATH." 
+ ) # Find wheel - print('Finding wheels...') - wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl')) - output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process') + print("Finding wheels...") + wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl")) + output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process") - image_binary = 'image.pyd' - video_binary = 'video_reader.pyd' + image_binary = "image.pyd" + video_binary = "video_reader.pyd" torchvision_binaries = [image_binary, video_binary] for wheel in wheels: if osp.exists(output_dir): @@ -393,25 +373,24 @@ def patch_win(): os.makedirs(output_dir) - print('Unzipping wheel...') + print("Unzipping wheel...") wheel_file = osp.basename(wheel) wheel_dir = osp.dirname(wheel) - print('{0}'.format(wheel_file)) + print("{0}".format(wheel_file)) wheel_name, _ = osp.splitext(wheel_file) unzip_file(wheel, output_dir) - print('Finding DLL/PE dependencies...') - output_library = osp.join(output_dir, 'torchvision') + print("Finding DLL/PE dependencies...") + output_library = osp.join(output_dir, "torchvision") for binary in torchvision_binaries: if osp.exists(osp.join(output_library, binary)): - relocate_dll_library( - dumpbin, output_dir, output_library, binary) + relocate_dll_library(dumpbin, output_dir, output_library, binary) compress_wheel(output_dir, wheel, wheel_dir, wheel_name) -if __name__ == '__main__': - if sys.platform == 'linux': +if __name__ == "__main__": + if sys.platform == "linux": patch_linux() - elif sys.platform == 'win32': + elif sys.platform == "win32": patch_win() diff --git a/references/classification/presets.py b/references/classification/presets.py index 6bb389ba8db..3a7bfe1eaa0 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -2,18 +2,27 @@ class ClassificationPresetTrain: - def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5, - auto_augment_policy=None, random_erase_prob=0.0): + def __init__( + self, + crop_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + hflip_prob=0.5, + auto_augment_policy=None, + random_erase_prob=0.0, + ): trans = [transforms.RandomResizedCrop(crop_size)] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy)) - trans.extend([ - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) + trans.extend( + [ + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) if random_erase_prob > 0: trans.append(transforms.RandomErasing(p=random_erase_prob)) @@ -26,12 +35,14 @@ def __call__(self, img): class ClassificationPresetEval: def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = transforms.Compose([ - transforms.Resize(resize_size), - transforms.CenterCrop(crop_size), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) + self.transforms = transforms.Compose( + [ + transforms.Resize(resize_size), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) def __call__(self, img): return self.transforms(img) diff --git a/references/classification/train.py b/references/classification/train.py index b4e9d274662..a02f95f6dd9 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -2,13 +2,13 @@ import os 
import time +import presets +import utils + import torch import torch.utils.data -from torch import nn import torchvision - -import presets -import utils +from torch import nn try: from apex import amp @@ -19,10 +19,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) + metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) - header = 'Epoch: [{}]'.format(epoch) + header = "Epoch: [{}]".format(epoch) for image, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() image, target = image.to(device), target.to(device) @@ -40,15 +40,15 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) def evaluate(model, criterion, data_loader, device, print_freq=100): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True) @@ -61,18 +61,22 @@ def evaluate(model, criterion, data_loader, device, print_freq=100): # could have been padded in distributed setup batch_size = image.shape[0] metric_logger.update(loss=loss.item()) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() - print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' - .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + print( + " * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}".format( + top1=metric_logger.acc1, top5=metric_logger.acc5 + ) + ) return metric_logger.acc1.global_avg def _get_cache_path(filepath): import hashlib + h = hashlib.sha1(filepath.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) @@ -82,7 +86,7 @@ def _get_cache_path(filepath): def load_data(traindir, valdir, args): # Data loading code print("Loading data") - resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224) + resize_size, crop_size = (342, 299) if args.model == "inception_v3" else (256, 224) print("Loading training data") st = time.time() @@ -96,8 +100,10 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( 
traindir, - presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy, - random_erase_prob=random_erase_prob)) + presets.ClassificationPresetTrain( + crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob + ), + ) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) @@ -112,8 +118,8 @@ def load_data(traindir, valdir, args): dataset_test, _ = torch.load(cache_path) else: dataset_test = torchvision.datasets.ImageFolder( - valdir, - presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size)) + valdir, presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size) + ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) @@ -132,8 +138,10 @@ def load_data(traindir, valdir, args): def main(args): if args.apex and amp is None: - raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " - "to enable mixed-precision training.") + raise RuntimeError( + "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training." + ) if args.output_dir: utils.mkdir(args.output_dir) @@ -145,16 +153,16 @@ def main(args): torch.backends.cudnn.benchmark = True - train_dir = os.path.join(args.data_path, 'train') - val_dir = os.path.join(args.data_path, 'val') + train_dir = os.path.join(args.data_path, "train") + val_dir = os.path.join(args.data_path, "val") dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True) + dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, - sampler=test_sampler, num_workers=args.workers, pin_memory=True) + dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + ) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) @@ -165,19 +173,24 @@ def main(args): criterion = nn.CrossEntropyLoss() opt_name = args.opt.lower() - if opt_name == 'sgd': + if opt_name == "sgd": optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - elif opt_name == 'rmsprop': - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, - weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + elif opt_name == "rmsprop": + optimizer = torch.optim.RMSprop( + model.parameters(), + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + eps=0.0316, + alpha=0.9, + ) else: raise RuntimeError("Invalid optimizer {}. 
Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: - model, optimizer = amp.initialize(model, optimizer, - opt_level=args.apex_opt_level - ) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) @@ -187,11 +200,11 @@ def main(args): model_without_ddp = model.module if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -207,49 +220,51 @@ def main(args): evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help) - - parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') - parser.add_argument('--model', default='resnet18', help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=32, type=int) - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--opt', default='sgd', type=str, help='optimizer') - parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') - parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path 
where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset") + parser.add_argument("--model", default="resnet18", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("-b", "--batch-size", default=32, type=int) + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -274,22 +289,23 @@ def get_args_parser(add_help=True): help="Use pre-trained models from the modelzoo", action="store_true", ) - parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)') - parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)') + parser.add_argument("--auto-augment", default=None, help="auto augment policy (default: None)") + parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") # Mixed precision training parameters - parser.add_argument('--apex', action='store_true', - help='Use apex for mixed precision training') - parser.add_argument('--apex-opt-level', default='O1', type=str, - help='For apex mixed precision training' - 'O0 for FP32 training, O1 for mixed precision training.' - 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' - ) + parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") + parser.add_argument( + "--apex-opt-level", + default="O1", + type=str, + help="For apex mixed precision training" + "O0 for FP32 training, O1 for mixed precision training." 
+ "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", + ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index ec945f4f58f..e37c7587130 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -1,15 +1,16 @@ +import copy import datetime import os import time -import copy + +import utils +from train import evaluate, load_data, train_one_epoch import torch +import torch.quantization import torch.utils.data -from torch import nn import torchvision -import torch.quantization -import utils -from train import train_one_epoch, evaluate, load_data +from torch import nn def main(args): @@ -20,8 +21,7 @@ def main(args): print(args) if args.post_training_quantize and args.distributed: - raise RuntimeError("Post training quantization example should not be performed " - "on distributed mode") + raise RuntimeError("Post training quantization example should not be performed " "on distributed mode") # Set backend engine to ensure that quantized model runs on the correct kernels if args.backend not in torch.backends.quantized.supported_engines: @@ -33,17 +33,17 @@ def main(args): # Data loading code print("Loading data") - train_dir = os.path.join(args.data_path, 'train') - val_dir = os.path.join(args.data_path, 'val') + train_dir = os.path.join(args.data_path, "train") + val_dir = os.path.join(args.data_path, "val") dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True) + dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.eval_batch_size, - sampler=test_sampler, num_workers=args.workers, pin_memory=True) + dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + ) print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model @@ -59,12 +59,10 @@ def main(args): model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, - weight_decay=args.weight_decay) + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, - step_size=args.lr_step_size, - gamma=args.lr_gamma) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) criterion = nn.CrossEntropyLoss() model_without_ddp = model @@ -73,21 +71,19 @@ def main(args): model_without_ddp = model.module if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - 
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.post_training_quantize: # perform calibration on a subset of the training dataset # for that, create a subset of the training dataset - ds = torch.utils.data.Subset( - dataset, - indices=list(range(args.batch_size * args.num_calibration_batches))) + ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches))) data_loader_calibration = torch.utils.data.DataLoader( - ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, - pin_memory=True) + ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) model.eval() model.fuse_model() model.qconfig = torch.quantization.get_default_qconfig(args.backend) @@ -97,10 +93,9 @@ def main(args): evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1) torch.quantization.convert(model, inplace=True) if args.output_dir: - print('Saving quantized model') + print("Saving quantized model") if utils.is_main_process(): - torch.save(model.state_dict(), os.path.join(args.output_dir, - 'quantized_post_train_model.pth')) + torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth")) print("Evaluating post-training quantized model") evaluate(model, criterion, data_loader_test, device=device) return @@ -115,107 +110,103 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - print('Starting training for epoch', epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, - args.print_freq) + print("Starting training for epoch", epoch) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq) lr_scheduler.step() with torch.no_grad(): if epoch >= args.num_observer_update_epochs: - print('Disabling observer for subseq epochs, epoch = ', epoch) + print("Disabling observer for subseq epochs, epoch = ", epoch) model.apply(torch.quantization.disable_observer) if epoch >= args.num_batch_norm_update_epochs: - print('Freezing BN for subseq epochs, epoch = ', epoch) + print("Freezing BN for subseq epochs, epoch = ", epoch) model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) - print('Evaluate QAT model') + print("Evaluate QAT model") evaluate(model, criterion, data_loader_test, device=device) quantized_eval_model = copy.deepcopy(model_without_ddp) quantized_eval_model.eval() - quantized_eval_model.to(torch.device('cpu')) + quantized_eval_model.to(torch.device("cpu")) torch.quantization.convert(quantized_eval_model, inplace=True) - print('Evaluate Quantized model') - evaluate(quantized_eval_model, criterion, data_loader_test, - device=torch.device('cpu')) + print("Evaluate Quantized model") + evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu")) model.train() if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'eval_model': quantized_eval_model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 
'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) - print('Saving models after epoch ', epoch) + "model": model_without_ddp.state_dict(), + "eval_model": quantized_eval_model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) + print("Saving models after epoch ", epoch) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Quantized Classification Training', add_help=add_help) - - parser.add_argument('--data-path', - default='/datasets01/imagenet_full_size/061417/', - help='dataset') - parser.add_argument('--model', - default='mobilenet_v2', - help='model') - parser.add_argument('--backend', - default='qnnpack', - help='fbgemm or qnnpack') - parser.add_argument('--device', - default='cuda', - help='device') - - parser.add_argument('-b', '--batch-size', default=32, type=int, - help='batch size for calibration/training') - parser.add_argument('--eval-batch-size', default=128, type=int, - help='batch size for evaluation') - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('--num-observer-update-epochs', - default=4, type=int, metavar='N', - help='number of total epochs to update observers') - parser.add_argument('--num-batch-norm-update-epochs', default=3, - type=int, metavar='N', - help='number of total epochs to update batch norm stats') - parser.add_argument('--num-calibration-batches', - default=32, type=int, metavar='N', - help='number of batches of training set for \ - observer calibration ') - - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--lr', - default=0.0001, type=float, - help='initial learning rate') - parser.add_argument('--momentum', - default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-step-size', default=30, type=int, - help='decrease lr every step-size epochs') - parser.add_argument('--lr-gamma', default=0.1, type=float, - help='decrease lr by a factor of lr-gamma') - parser.add_argument('--print-freq', default=10, type=int, - help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset") + parser.add_argument("--model", default="mobilenet_v2", help="model") + parser.add_argument("--backend", default="qnnpack", help="fbgemm or qnnpack") + parser.add_argument("--device", default="cuda", help="device") + + 
parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size for calibration/training") + parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation") + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "--num-observer-update-epochs", + default=4, + type=int, + metavar="N", + help="number of total epochs to update observers", + ) + parser.add_argument( + "--num-batch-norm-update-epochs", + default=3, + type=int, + metavar="N", + help="number of total epochs to update batch norm stats", + ) + parser.add_argument( + "--num-calibration-batches", + default=32, + type=int, + metavar="N", + help="number of batches of training set for \ + observer calibration ", + ) + + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -243,11 +234,8 @@ def get_args_parser(add_help=True): ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', - default='env://', - help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 4e53ed1d3d7..4a2605e14ea 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -1,14 +1,14 @@ -from collections import defaultdict, deque, OrderedDict import copy import datetime +import errno import hashlib +import os import time +from collections import OrderedDict, defaultdict, deque + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -34,7 +34,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -65,11 +65,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, 
global_avg=self.global_avg, max=self.max, value=self.value + ) class MetricLogger(object): @@ -89,15 +86,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -110,31 +104,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -144,21 +135,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) def accuracy(output, target, topk=(1,)): @@ -191,10 +189,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -231,28 +230,28 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) @@ -275,9 +274,7 @@ def average_checkpoints(inputs): with open(fpath, "rb") as f: state = torch.load( f, - map_location=( - lambda s, _: torch.serialization.default_restore_location(s, "cpu") - ), + map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), ) # Copies over the settings from the first checkpoint if new_state is None: @@ -311,7 +308,7 @@ def average_checkpoints(inputs): return new_state -def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True): +def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): """ This method can be used to prepare weights files for new models. It receives as input a model architecture and a checkpoint from the training script and produces @@ -357,7 +354,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=T # Deep copy to avoid side-effects on the model object. 
model = copy.deepcopy(model) - checkpoint = torch.load(checkpoint_path, map_location='cpu') + checkpoint = torch.load(checkpoint_path, map_location="cpu") # Load the weights to the model to validate that everything works # and remove unnecessary weights (such as auxiliaries, etc) diff --git a/references/detection/coco_eval.py b/references/detection/coco_eval.py index 09648f29ae4..40831bedac6 100644 --- a/references/detection/coco_eval.py +++ b/references/detection/coco_eval.py @@ -1,19 +1,15 @@ +import copy import json -import tempfile +from collections import defaultdict import numpy as np -import copy -import time -import torch -import torch._six - -from pycocotools.cocoeval import COCOeval -from pycocotools.coco import COCO import pycocotools.mask as mask_util - -from collections import defaultdict - import utils +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +import torch +import torch._six class CocoEvaluator(object): @@ -109,8 +105,7 @@ def prepare_for_coco_segmentation(self, predictions): labels = prediction["labels"].tolist() rles = [ - mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] - for mask in masks + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks ] for rle in rles: rle["counts"] = rle["counts"].decode("utf-8") @@ -146,7 +141,7 @@ def prepare_for_coco_keypoint(self, predictions): { "image_id": original_id, "category_id": labels[k], - 'keypoints': keypoint, + "keypoints": keypoint, "score": scores[k], } for k, keypoint in enumerate(keypoints) @@ -200,27 +195,28 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs): # Ideally, pycocotools wouldn't have hard-coded prints # so that we could avoid copy-pasting those two functions + def createIndex(self): # create index # print('creating index...') anns, cats, imgs = {}, {}, {} imgToAnns, catToImgs = defaultdict(list), defaultdict(list) - if 'annotations' in self.dataset: - for ann in self.dataset['annotations']: - imgToAnns[ann['image_id']].append(ann) - anns[ann['id']] = ann + if "annotations" in self.dataset: + for ann in self.dataset["annotations"]: + imgToAnns[ann["image_id"]].append(ann) + anns[ann["id"]] = ann - if 'images' in self.dataset: - for img in self.dataset['images']: - imgs[img['id']] = img + if "images" in self.dataset: + for img in self.dataset["images"]: + imgs[img["id"]] = img - if 'categories' in self.dataset: - for cat in self.dataset['categories']: - cats[cat['id']] = cat + if "categories" in self.dataset: + for cat in self.dataset["categories"]: + cats[cat["id"]] = cat - if 'annotations' in self.dataset and 'categories' in self.dataset: - for ann in self.dataset['annotations']: - catToImgs[ann['category_id']].append(ann['image_id']) + if "annotations" in self.dataset and "categories" in self.dataset: + for ann in self.dataset["annotations"]: + catToImgs[ann["category_id"]].append(ann["image_id"]) # print('index created!') @@ -245,7 +241,7 @@ def loadRes(self, resFile): res (obj): result api object """ res = COCO() - res.dataset['images'] = [img for img in self.dataset['images']] + res.dataset["images"] = [img for img in self.dataset["images"]] # print('Loading and preparing results...') # tic = time.time() @@ -255,63 +251,62 @@ def loadRes(self, resFile): anns = self.loadNumpyAnnotations(resFile) else: anns = resFile - assert type(anns) == list, 'results in not an array of objects' - annsImgIds = [ann['image_id'] for ann in anns] - assert set(annsImgIds) == (set(annsImgIds) & 
set(self.getImgIds())), \ - 'Results do not correspond to current coco set' - if 'caption' in anns[0]: - imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) - res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] + assert type(anns) == list, "results in not an array of objects" + annsImgIds = [ann["image_id"] for ann in anns] + assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), "Results do not correspond to current coco set" + if "caption" in anns[0]: + imgIds = set([img["id"] for img in res.dataset["images"]]) & set([ann["image_id"] for ann in anns]) + res.dataset["images"] = [img for img in res.dataset["images"] if img["id"] in imgIds] for id, ann in enumerate(anns): - ann['id'] = id + 1 - elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: - res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + ann["id"] = id + 1 + elif "bbox" in anns[0] and not anns[0]["bbox"] == []: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) for id, ann in enumerate(anns): - bb = ann['bbox'] + bb = ann["bbox"] x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] - if 'segmentation' not in ann: - ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] - ann['area'] = bb[2] * bb[3] - ann['id'] = id + 1 - ann['iscrowd'] = 0 - elif 'segmentation' in anns[0]: - res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + if "segmentation" not in ann: + ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann["area"] = bb[2] * bb[3] + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "segmentation" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) for id, ann in enumerate(anns): # now only support compressed RLE format as segmentation results - ann['area'] = maskUtils.area(ann['segmentation']) - if 'bbox' not in ann: - ann['bbox'] = maskUtils.toBbox(ann['segmentation']) - ann['id'] = id + 1 - ann['iscrowd'] = 0 - elif 'keypoints' in anns[0]: - res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + ann["area"] = maskUtils.area(ann["segmentation"]) + if "bbox" not in ann: + ann["bbox"] = maskUtils.toBbox(ann["segmentation"]) + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "keypoints" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) for id, ann in enumerate(anns): - s = ann['keypoints'] + s = ann["keypoints"] x = s[0::3] y = s[1::3] x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y) - ann['area'] = (x2 - x1) * (y2 - y1) - ann['id'] = id + 1 - ann['bbox'] = [x1, y1, x2 - x1, y2 - y1] + ann["area"] = (x2 - x1) * (y2 - y1) + ann["id"] = id + 1 + ann["bbox"] = [x1, y1, x2 - x1, y2 - y1] # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) - res.dataset['annotations'] = anns + res.dataset["annotations"] = anns createIndex(res) return res def evaluate(self): - ''' + """ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs :return: None - ''' + """ # tic = time.time() # print('Running per image evaluation...') p = self.params # add backward compatibility if useSegm is specified in params if p.useSegm is not None: - p.iouType = 'segm' if p.useSegm == 1 else 'bbox' - print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) + p.iouType = "segm" if p.useSegm == 1 else "bbox" + print("useSegm (deprecated) is not None. 
Running {} evaluation".format(p.iouType)) # print('Evaluate annotation type *{}*'.format(p.iouType)) p.imgIds = list(np.unique(p.imgIds)) if p.useCats: @@ -323,22 +318,16 @@ def evaluate(self): # loop through images, area range, max detection number catIds = p.catIds if p.useCats else [-1] - if p.iouType == 'segm' or p.iouType == 'bbox': + if p.iouType == "segm" or p.iouType == "bbox": computeIoU = self.computeIoU - elif p.iouType == 'keypoints': + elif p.iouType == "keypoints": computeIoU = self.computeOks - self.ious = { - (imgId, catId): computeIoU(imgId, catId) - for imgId in p.imgIds - for catId in catIds} + self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds} evaluateImg = self.evaluateImg maxDet = p.maxDets[-1] evalImgs = [ - evaluateImg(imgId, catId, areaRng, maxDet) - for catId in catIds - for areaRng in p.areaRng - for imgId in p.imgIds + evaluateImg(imgId, catId, areaRng, maxDet) for catId in catIds for areaRng in p.areaRng for imgId in p.imgIds ] # this is NOT in the pycocotools code, but could be done outside evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) @@ -347,6 +336,7 @@ def evaluate(self): # print('DONE (t={:0.2f}s).'.format(toc-tic)) return p.imgIds, evalImgs + ################################################################# # end of straight copy from pycocotools, just removing the prints ################################################################# diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 26701a2cbee..c7390d06cc6 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -1,15 +1,13 @@ import copy import os -from PIL import Image - -import torch -import torch.utils.data -import torchvision +import transforms as T from pycocotools import mask as coco_mask from pycocotools.coco import COCO -import transforms as T +import torch +import torch.utils.data +import torchvision class FilterAndRemapCocoCategories(object): @@ -56,7 +54,7 @@ def __call__(self, image, target): anno = target["annotations"] - anno = [obj for obj in anno if obj['iscrowd'] == 0] + anno = [obj for obj in anno if obj["iscrowd"] == 0] boxes = [obj["bbox"] for obj in anno] # guard against no boxes via resizing @@ -147,7 +145,7 @@ def convert_to_coco_api(ds): coco_ds = COCO() # annotation IDs need to start at 1, not 0, see torchvision issue #1530 ann_id = 1 - dataset = {'images': [], 'categories': [], 'annotations': []} + dataset = {"images": [], "categories": [], "annotations": []} categories = set() for img_idx in range(len(ds)): # find better way to get target @@ -155,41 +153,41 @@ def convert_to_coco_api(ds): img, targets = ds[img_idx] image_id = targets["image_id"].item() img_dict = {} - img_dict['id'] = image_id - img_dict['height'] = img.shape[-2] - img_dict['width'] = img.shape[-1] - dataset['images'].append(img_dict) + img_dict["id"] = image_id + img_dict["height"] = img.shape[-2] + img_dict["width"] = img.shape[-1] + dataset["images"].append(img_dict) bboxes = targets["boxes"] bboxes[:, 2:] -= bboxes[:, :2] bboxes = bboxes.tolist() - labels = targets['labels'].tolist() - areas = targets['area'].tolist() - iscrowd = targets['iscrowd'].tolist() - if 'masks' in targets: - masks = targets['masks'] + labels = targets["labels"].tolist() + areas = targets["area"].tolist() + iscrowd = targets["iscrowd"].tolist() + if "masks" in targets: + masks = targets["masks"] # make masks Fortran contiguous for coco_mask masks = masks.permute(0, 2, 
1).contiguous().permute(0, 2, 1) - if 'keypoints' in targets: - keypoints = targets['keypoints'] + if "keypoints" in targets: + keypoints = targets["keypoints"] keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() num_objs = len(bboxes) for i in range(num_objs): ann = {} - ann['image_id'] = image_id - ann['bbox'] = bboxes[i] - ann['category_id'] = labels[i] + ann["image_id"] = image_id + ann["bbox"] = bboxes[i] + ann["category_id"] = labels[i] categories.add(labels[i]) - ann['area'] = areas[i] - ann['iscrowd'] = iscrowd[i] - ann['id'] = ann_id - if 'masks' in targets: + ann["area"] = areas[i] + ann["iscrowd"] = iscrowd[i] + ann["id"] = ann_id + if "masks" in targets: ann["segmentation"] = coco_mask.encode(masks[i].numpy()) - if 'keypoints' in targets: - ann['keypoints'] = keypoints[i] - ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) - dataset['annotations'].append(ann) + if "keypoints" in targets: + ann["keypoints"] = keypoints[i] + ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) + dataset["annotations"].append(ann) ann_id += 1 - dataset['categories'] = [{'id': i} for i in sorted(categories)] + dataset["categories"] = [{"id": i} for i in sorted(categories)] coco_ds.dataset = dataset coco_ds.createIndex() return coco_ds @@ -220,7 +218,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode='instances'): +def get_coco(root, image_set, transforms, mode="instances"): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), diff --git a/references/detection/engine.py b/references/detection/engine.py index 49992af60a9..c15a3c1d2a0 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -1,24 +1,24 @@ import math import sys import time -import torch - -import torchvision.models.detection.mask_rcnn -from coco_utils import get_coco_api_from_dataset -from coco_eval import CocoEvaluator import utils +from coco_eval import CocoEvaluator +from coco_utils import get_coco_api_from_dataset + +import torch +import torchvision.models.detection.mask_rcnn def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) - header = 'Epoch: [{}]'.format(epoch) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) lr_scheduler = None if epoch == 0: - warmup_factor = 1. 
/ 1000 + warmup_factor = 1.0 / 1000 warmup_iters = min(1000, len(data_loader) - 1) lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) @@ -75,7 +75,7 @@ def evaluate(model, data_loader, device): cpu_device = torch.device("cpu") model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" coco = get_coco_api_from_dataset(data_loader.dataset) iou_types = _get_iou_types(model) diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py index 1b76f4c64f7..a870e0a40f2 100644 --- a/references/detection/group_by_aspect_ratio.py +++ b/references/detection/group_by_aspect_ratio.py @@ -1,17 +1,17 @@ import bisect -from collections import defaultdict import copy -from itertools import repeat, chain import math +from collections import defaultdict +from itertools import chain, repeat + import numpy as np +from PIL import Image import torch import torch.utils.data +import torchvision from torch.utils.data.sampler import BatchSampler, Sampler from torch.utils.model_zoo import tqdm -import torchvision - -from PIL import Image def _repeat_to_at_least(iterable, n): @@ -34,11 +34,11 @@ class GroupedBatchSampler(BatchSampler): 0, i.e. they must be in the range [0, num_groups). batch_size (int): Size of mini-batch. """ + def __init__(self, sampler, group_ids, batch_size): if not isinstance(sampler, Sampler): raise ValueError( - "sampler should be an instance of " - "torch.utils.data.Sampler, but got sampler={}".format(sampler) + "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler) ) self.sampler = sampler self.group_ids = group_ids @@ -68,8 +68,7 @@ def __iter__(self): if num_remaining > 0: # for the remaining batches, take first the buffers with largest number # of elements - for group_id, _ in sorted(buffer_per_group.items(), - key=lambda x: len(x[1]), reverse=True): + for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True): remaining = self.batch_size - len(buffer_per_group[group_id]) samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining) buffer_per_group[group_id].extend(samples_from_group_id[:remaining]) @@ -85,10 +84,12 @@ def __len__(self): def _compute_aspect_ratios_slow(dataset, indices=None): - print("Your dataset doesn't support the fast path for " - "computing the aspect ratios, so will iterate over " - "the full dataset and load every image instead. " - "This might take some time...") + print( + "Your dataset doesn't support the fast path for " + "computing the aspect ratios, so will iterate over " + "the full dataset and load every image instead. " + "This might take some time..." 
+ ) if indices is None: indices = range(len(dataset)) @@ -104,9 +105,12 @@ def __len__(self): sampler = SubsetSampler(indices) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=1, sampler=sampler, + dataset, + batch_size=1, + sampler=sampler, num_workers=14, # you might want to increase it for faster processing - collate_fn=lambda x: x[0]) + collate_fn=lambda x: x[0], + ) aspect_ratios = [] with tqdm(total=len(dataset)) as pbar: for _i, (img, _) in enumerate(data_loader): diff --git a/references/detection/presets.py b/references/detection/presets.py index 1fac69ae356..f04e8352ad5 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -2,26 +2,32 @@ class DetectionPresetTrain: - def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)): - if data_augmentation == 'hflip': - self.transforms = T.Compose([ - T.RandomHorizontalFlip(p=hflip_prob), - T.ToTensor(), - ]) - elif data_augmentation == 'ssd': - self.transforms = T.Compose([ - T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.ToTensor(), - ]) - elif data_augmentation == 'ssdlite': - self.transforms = T.Compose([ - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.ToTensor(), - ]) + def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): + if data_augmentation == "hflip": + self.transforms = T.Compose( + [ + T.RandomHorizontalFlip(p=hflip_prob), + T.ToTensor(), + ] + ) + elif data_augmentation == "ssd": + self.transforms = T.Compose( + [ + T.RandomPhotometricDistort(), + T.RandomZoomOut(fill=list(mean)), + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToTensor(), + ] + ) + elif data_augmentation == "ssdlite": + self.transforms = T.Compose( + [ + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToTensor(), + ] + ) else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') diff --git a/references/detection/train.py b/references/detection/train.py index cd4148e9bf7..e83b1067465 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -21,26 +21,21 @@ import os import time +import presets +import utils +from coco_utils import get_coco, get_coco_kp +from engine import evaluate, train_one_epoch +from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups + import torch import torch.utils.data import torchvision import torchvision.models.detection import torchvision.models.detection.mask_rcnn -from coco_utils import get_coco, get_coco_kp - -from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups -from engine import train_one_epoch, evaluate - -import presets -import utils - def get_dataset(name, image_set, transform, data_path): - paths = { - "coco": (data_path, get_coco, 91), - "coco_kp": (data_path, get_coco_kp, 2) - } + paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] ds = ds_fn(p, image_set=image_set, transforms=transform) @@ -53,42 +48,60 @@ def get_transform(train, data_augmentation): def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help) - - parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset') - parser.add_argument('--dataset', default='coco', help='dataset') - parser.add_argument('--model', default='maskrcnn_resnet50_fpn', 
help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=2, type=int, - help='images per gpu, the total batch size is $NGPU x batch_size') - parser.add_argument('--epochs', default=26, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') - parser.add_argument('--lr', default=0.02, type=float, - help='initial learning rate, 0.02 is the default value for training ' - 'on 8 gpus and 2 images_per_gpu') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)') - parser.add_argument('--lr-step-size', default=8, type=int, - help='decrease lr every step-size epochs (multisteplr scheduler only)') - parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, - help='decrease lr every step-size epochs (multisteplr scheduler only)') - parser.add_argument('--lr-gamma', default=0.1, type=float, - help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)') - parser.add_argument('--print-freq', default=20, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') - parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) - parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn') - parser.add_argument('--trainable-backbone-layers', default=None, type=int, - help='number of trainable layers of backbone') - parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)') + + parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/COCO/022719/", help="dataset") + parser.add_argument("--dataset", default="coco", help="dataset") + parser.add_argument("--model", default="maskrcnn_resnet50_fpn", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument( + "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" + ) + parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)" + ) + parser.add_argument( + "--lr", + default=0.02, + type=float, + help="initial learning rate, 0.02 is the default value for training " "on 8 gpus and 2 images_per_gpu", + ) + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-scheduler", default="multisteplr", help="the lr scheduler (default: multisteplr)") + parser.add_argument( + "--lr-step-size", default=8, type=int, help="decrease lr every step-size 
epochs (multisteplr scheduler only)" + ) + parser.add_argument( + "--lr-steps", + default=[16, 22], + nargs="+", + type=int, + help="decrease lr every step-size epochs (multisteplr scheduler only)", + ) + parser.add_argument( + "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)" + ) + parser.add_argument("--print-freq", default=20, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start_epoch", default=0, type=int, help="start epoch") + parser.add_argument("--aspect-ratio-group-factor", default=3, type=int) + parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn") + parser.add_argument( + "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone" + ) + parser.add_argument("--data-augmentation", default="hflip", help="data augmentation policy (default: hflip)") parser.add_argument( "--sync-bn", dest="sync_bn", @@ -109,9 +122,8 @@ def get_args_parser(add_help=True): ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser @@ -128,8 +140,9 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation), - args.data_path) + dataset, num_classes = get_dataset( + args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path + ) dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path) print("Creating data loaders") @@ -144,27 +157,24 @@ def main(args): group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) else: - train_batch_sampler = torch.utils.data.BatchSampler( - train_sampler, args.batch_size, drop_last=True) + train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True) data_loader = torch.utils.data.DataLoader( - dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=1, - sampler=test_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) print("Creating model") - kwargs = { - "trainable_backbone_layers": args.trainable_backbone_layers - } + kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, - **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + 
num_classes=num_classes, pretrained=args.pretrained, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -175,24 +185,25 @@ def main(args): model_without_ddp = model.module params = [p for p in model.parameters() if p.requires_grad] - optimizer = torch.optim.SGD( - params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) args.lr_scheduler = args.lr_scheduler.lower() - if args.lr_scheduler == 'multisteplr': + if args.lr_scheduler == "multisteplr": lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) - elif args.lr_scheduler == 'cosineannealinglr': + elif args.lr_scheduler == "cosineannealinglr": lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) else: - raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " - "are supported.".format(args.lr_scheduler)) + raise RuntimeError( + "Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " + "are supported.".format(args.lr_scheduler) + ) if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: evaluate(model, data_loader_test, device=device) @@ -207,25 +218,21 @@ def main(args): lr_scheduler.step() if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'args': args, - 'epoch': epoch + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "args": args, + "epoch": epoch, } - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) # evaluate after every epoch evaluate(model, data_loader_test, device=device) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) if __name__ == "__main__": diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 8e4b8870eaf..58ac03f2da3 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,10 +1,9 @@ +from typing import Dict, List, Optional, Tuple + import torch import torchvision - -from torch import nn, Tensor -from torchvision.transforms import functional as F -from torchvision.transforms import transforms as T -from typing import List, Tuple, Dict, Optional +from torch import Tensor, nn +from torchvision.transforms import functional as F, transforms as T def 
_flip_coco_person_keypoints(kps, width): @@ -28,8 +27,9 @@ def __call__(self, image, target): class RandomHorizontalFlip(T.RandomHorizontalFlip): - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if torch.rand(1) < self.p: image = F.hflip(image) if target is not None: @@ -45,15 +45,23 @@ def forward(self, image: Tensor, class ToTensor(nn.Module): - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: image = F.to_tensor(image) return image, target class RandomIoUCrop(nn.Module): - def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5, - max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): super().__init__() # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 self.min_scale = min_scale @@ -65,14 +73,15 @@ def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ra self.options = sampler_options self.trials = trials - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if target is None: raise ValueError("The targets can't be None for this transform.") if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) @@ -112,8 +121,9 @@ def forward(self, image: Tensor, # check at least 1 box with jaccard limitations boxes = target["boxes"][is_within_crop_area] - ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]], - dtype=boxes.dtype, device=boxes.device)) + ious = torchvision.ops.boxes.box_iou( + boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device) + ) if ious.max() < min_jaccard_overlap: continue @@ -130,13 +140,15 @@ def forward(self, image: Tensor, class RandomZoomOut(nn.Module): - def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5): + def __init__( + self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 + ): super().__init__() if fill is None: - fill = [0., 0., 0.] + fill = [0.0, 0.0, 0.0] self.fill = fill self.side_range = side_range - if side_range[0] < 1. 
or side_range[0] > side_range[1]: + if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError("Invalid canvas side range provided {}.".format(side_range)) self.p = p @@ -146,11 +158,12 @@ def _get_fill_value(self, is_pil): # We fake the type to make it work on JIT return tuple(int(x) for x in self.fill) if is_pil else 0 - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) @@ -177,8 +190,9 @@ def forward(self, image: Tensor, image = F.pad(image, [left, top, right, bottom], fill=fill) if isinstance(image, torch.Tensor): v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) - image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \ - image[..., :, (left + orig_w):] = v + image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[ + ..., :, (left + orig_w) : + ] = v if target is not None: target["boxes"][:, 0::2] += left @@ -188,8 +202,14 @@ def forward(self, image: Tensor, class RandomPhotometricDistort(nn.Module): - def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5), - hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5): + def __init__( + self, + contrast: Tuple[float] = (0.5, 1.5), + saturation: Tuple[float] = (0.5, 1.5), + hue: Tuple[float] = (-0.05, 0.05), + brightness: Tuple[float] = (0.875, 1.125), + p: float = 0.5, + ): super().__init__() self._brightness = T.ColorJitter(brightness=brightness) self._contrast = T.ColorJitter(contrast=contrast) @@ -197,11 +217,12 @@ def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] self._saturation = T.ColorJitter(saturation=saturation) self.p = p - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. 
Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) diff --git a/references/detection/utils.py b/references/detection/utils.py index 3c52abb2167..13bb90f4079 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -1,8 +1,8 @@ -from collections import defaultdict, deque import datetime import errno import os import time +from collections import defaultdict, deque import torch import torch.distributed as dist @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) def all_gather(data): @@ -130,15 +127,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -151,31 +145,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -185,22 +176,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - 
print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.4f} s / it)'.format( - header, total_time_str, total_time / len(iterable))) + print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) def collate_fn(batch): @@ -208,7 +205,6 @@ def collate_fn(batch): def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): - def f(x): if x >= warmup_iters: return 1 @@ -231,10 +227,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -271,25 +268,25 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index c86d5495247..c21fa84826c 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -1,15 +1,14 @@ import copy -import torch -import torch.utils.data -import torchvision -from PIL import Image - import os +from PIL import Image from pycocotools import mask as coco_mask - from transforms import Compose +import torch +import torch.utils.data +import torchvision + class FilterAndRemapCocoCategories(object): def __init__(self, categories, remap=True): @@ -90,14 +89,9 @@ def get_coco(root, image_set, transforms): "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) } - CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, - 1, 64, 20, 63, 7, 72] - - transforms = Compose([ - FilterAndRemapCocoCategories(CAT_LIST, remap=True), - ConvertCocoPolysToMask(), - transforms - ]) + CAT_LIST = [0, 
5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] + + transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 3bf29c23751..b7838e44cb0 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -9,11 +9,13 @@ def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.4 trans = [T.RandomResize(min_size, max_size)] if hflip_prob > 0: trans.append(T.RandomHorizontalFlip(hflip_prob)) - trans.extend([ - T.RandomCrop(crop_size), - T.ToTensor(), - T.Normalize(mean=mean, std=std), - ]) + trans.extend( + [ + T.RandomCrop(crop_size), + T.ToTensor(), + T.Normalize(mean=mean, std=std), + ] + ) self.transforms = T.Compose(trans) def __call__(self, img, target): @@ -22,11 +24,13 @@ def __call__(self, img, target): class SegmentationPresetEval: def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = T.Compose([ - T.RandomResize(base_size, base_size), - T.ToTensor(), - T.Normalize(mean=mean, std=std), - ]) + self.transforms = T.Compose( + [ + T.RandomResize(base_size, base_size), + T.ToTensor(), + T.Normalize(mean=mean, std=std), + ] + ) def __call__(self, img, target): return self.transforms(img, target) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index fb6c7eeee15..2e1fe6357c6 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -2,23 +2,24 @@ import os import time +import presets +import utils +from coco_utils import get_coco + import torch import torch.utils.data -from torch import nn import torchvision - -from coco_utils import get_coco -import presets -import utils +from torch import nn def get_dataset(dir_path, name, image_set, transform): def sbd(*args, **kwargs): - return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) + return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) + paths = { "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21), "voc_aug": (dir_path, sbd, 21), - "coco": (dir_path, get_coco, 21) + "coco": (dir_path, get_coco, 21), } p, ds_fn, num_classes = paths[name] @@ -39,21 +40,21 @@ def criterion(inputs, target): losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) if len(losses) == 1: - return losses['out'] + return losses["out"] - return losses['out'] + 0.5 * losses['aux'] + return losses["out"] + 0.5 * losses["aux"] def evaluate(model, data_loader, device, num_classes): model.eval() confmat = utils.ConfusionMatrix(num_classes) metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, 100, header): image, target = image.to(device), target.to(device) output = model(image) - output = output['out'] + output = output["out"] confmat.update(target.flatten(), output.argmax(1).flatten()) @@ -65,8 +66,8 @@ def evaluate(model, data_loader, device, num_classes): def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - header = 'Epoch: [{}]'.format(epoch) + metric_logger.add_meter("lr", 
utils.SmoothedValue(window_size=1, fmt="{value}")) + header = "Epoch: [{}]".format(epoch) for image, target in metric_logger.log_every(data_loader, print_freq, header): image, target = image.to(device), target.to(device) output = model(image) @@ -101,18 +102,21 @@ def main(args): test_sampler = torch.utils.data.SequentialSampler(dataset_test) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn, drop_last=True) + dataset, + batch_size=args.batch_size, + sampler=train_sampler, + num_workers=args.workers, + collate_fn=utils.collate_fn, + drop_last=True, + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=1, - sampler=test_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) - model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, - aux_loss=args.aux_loss, - pretrained=args.pretrained) + model = torchvision.models.segmentation.__dict__[args.model]( + num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -129,21 +133,19 @@ def main(args): if args.aux_loss: params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] params_to_optimize.append({"params": params, "lr": args.lr * 10}) - optimizer = torch.optim.SGD( - params_to_optimize, - lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, - lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) + optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9 + ) if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only) + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only) if not args.test_only: - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) @@ -158,50 +160,51 @@ def main(args): confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, } - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, 
"checkpoint.pth")) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Segmentation Training', add_help=add_help) - - parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path') - parser.add_argument('--dataset', default='coco', help='dataset name') - parser.add_argument('--model', default='fcn_resnet101', help='model') - parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=8, type=int) - parser.add_argument('--epochs', default=30, type=int, metavar='N', - help='number of total epochs to run') - - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/COCO/022719/", help="dataset path") + parser.add_argument("--dataset", default="coco", help="dataset name") + parser.add_argument("--model", default="fcn_resnet101", help="model") + parser.add_argument("--aux-loss", action="store_true", help="auxiliar loss") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("-b", "--batch-size", default=8, type=int) + parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run") + + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--test-only", dest="test_only", @@ -215,9 +218,8 @@ def get_args_parser(add_help=True): action="store_true", ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed 
training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 4fe5a5ad147..28c08fd97b7 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,6 +1,7 @@ +import random + import numpy as np from PIL import Image -import random import torch from torchvision import transforms as T diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index b67c18052fb..2bb5451289a 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -1,12 +1,12 @@ -from collections import defaultdict, deque import datetime +import errno +import os import time +from collections import defaultdict, deque + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class ConfusionMatrix(object): @@ -82,7 +79,7 @@ def update(self, a, b): with torch.no_grad(): k = (a >= 0) & (a < n) inds = n * a[k].to(torch.int64) + b[k] - self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) def reset(self): self.mat.zero_() @@ -104,15 +101,12 @@ def reduce_from_all_processes(self): def __str__(self): acc_global, acc, iu = self.compute() - return ( - 'global correct: {:.1f}\n' - 'average row correct: {}\n' - 'IoU: {}\n' - 'mean IoU: {:.1f}').format( - acc_global.item() * 100, - ['{:.1f}'.format(i) for i in (acc * 100).tolist()], - ['{:.1f}'.format(i) for i in (iu * 100).tolist()], - iu.mean().item() * 100) + return ("global correct: {:.1f}\n" "average row correct: {}\n" "IoU: {}\n" "mean IoU: {:.1f}").format( + acc_global.item() * 100, + ["{:.1f}".format(i) for i in (acc * 100).tolist()], + ["{:.1f}".format(i) for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ) class MetricLogger(object): @@ -132,15 +126,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -153,31 +144,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = 
SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -187,21 +175,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) def cat_list(images, fill_value=0): @@ -209,7 +204,7 @@ def cat_list(images, fill_value=0): batch_shape = (len(images),) + max_size batched_imgs = images[0].new(*batch_shape).fill_(fill_value) for img, pad_img in zip(images, batched_imgs): - pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) return batched_imgs @@ -233,10 +228,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -273,26 +269,26 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using 
distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) diff --git a/references/similarity/loss.py b/references/similarity/loss.py index 1fa4a89c762..237ad8e9e11 100644 --- a/references/similarity/loss.py +++ b/references/similarity/loss.py @@ -1,21 +1,21 @@ -''' +""" Pytorch adaptation of https://omoindrot.github.io/triplet-loss https://github.com/omoindrot/tensorflow-triplet-loss -''' +""" import torch import torch.nn as nn class TripletMarginLoss(nn.Module): - def __init__(self, margin=1.0, p=2., mining='batch_all'): + def __init__(self, margin=1.0, p=2.0, mining="batch_all"): super(TripletMarginLoss, self).__init__() self.margin = margin self.p = p self.mining = mining - if mining == 'batch_all': + if mining == "batch_all": self.loss_fn = batch_all_triplet_loss - if mining == 'batch_hard': + if mining == "batch_hard": self.loss_fn = batch_hard_triplet_loss def forward(self, embeddings, labels): diff --git a/references/similarity/sampler.py b/references/similarity/sampler.py index 0ae6d07a77c..591155fb449 100644 --- a/references/similarity/sampler.py +++ b/references/similarity/sampler.py @@ -1,7 +1,8 @@ +import random +from collections import defaultdict + import torch from torch.utils.data.sampler import Sampler -from collections import defaultdict -import random def create_groups(groups, k): diff --git a/references/similarity/test.py b/references/similarity/test.py index 8381e02e740..1618d4806de 100644 --- a/references/similarity/test.py +++ b/references/similarity/test.py @@ -1,15 +1,14 @@ import unittest from collections import defaultdict +from sampler import PKSampler + +import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import FakeData -import torchvision.transforms as transforms - -from sampler import PKSampler class Tester(unittest.TestCase): - def test_pksampler(self): p, k = 16, 4 @@ -19,8 +18,7 @@ def test_pksampler(self): self.assertRaises(AssertionError, PKSampler, targets, p, k) # Ensure p, k constraints on batch - dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), - transform=transforms.ToTensor()) + dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=transforms.ToTensor()) targets = [target.item() for _, target in dataset] sampler = PKSampler(targets, p, k) loader = DataLoader(dataset, batch_size=p * k, sampler=sampler) @@ -38,5 +36,5 @@ def test_pksampler(self): self.assertEqual(bins[b], k) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/references/similarity/train.py b/references/similarity/train.py index 9a166a14b38..198854cfb2d 100644 --- a/references/similarity/train.py +++ b/references/similarity/train.py @@ -1,16 +1,15 @@ import os +from loss import TripletMarginLoss +from model import EmbeddingNet +from sampler import PKSampler + import torch +import torchvision.transforms as transforms 
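The TripletMarginLoss wrapper in references/similarity/loss.py (reformatted above) picks its mining strategy at construction time: mining="batch_all" dispatches to batch_all_triplet_loss and mining="batch_hard" to batch_hard_triplet_loss. Below is a minimal usage sketch that is not part of this diff; the (loss, frac_pos_triplets) return format is an assumption inferred from the running_frac_pos_triplets bookkeeping in train_epoch, and the script is assumed to run from references/similarity so the local module resolves.

# Sketch only, not part of the diff. Assumes execution from references/similarity
# and that forward() returns (loss, frac_pos_triplets); adjust the unpacking if the API differs.
import torch

from loss import TripletMarginLoss

# "batch_all" averages the loss over every valid (anchor, positive, negative) triplet
# in the batch; "batch_hard" keeps only the hardest positive and hardest negative per anchor.
criterion = TripletMarginLoss(margin=0.2, mining="batch_hard")

embeddings = torch.randn(32, 128, requires_grad=True)  # e.g. one P*K batch of 128-d embeddings
labels = torch.randint(0, 8, (32,))                     # integer class labels for the batch

loss, frac_pos_triplets = criterion(embeddings, labels)
loss.backward()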
from torch.optim import Adam from torch.utils.data import DataLoader - -import torchvision.transforms as transforms from torchvision.datasets import FashionMNIST -from loss import TripletMarginLoss -from sampler import PKSampler -from model import EmbeddingNet - def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_freq): model.train() @@ -33,7 +32,7 @@ def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_f i += 1 avg_loss = running_loss / print_freq avg_trip = 100.0 * running_frac_pos_triplets / print_freq - print('[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%'.format(epoch, i, avg_loss, avg_trip)) + print("[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%".format(epoch, i, avg_loss, avg_trip)) running_loss = 0 running_frac_pos_triplets = 0 @@ -79,17 +78,17 @@ def evaluate(model, loader, device): threshold, accuracy = find_best_threshold(dists, targets, device) - print('accuracy: {:.3f}%, threshold: {:.2f}'.format(accuracy, threshold)) + print("accuracy: {:.3f}%, threshold: {:.2f}".format(accuracy, threshold)) def save(model, epoch, save_dir, file_name): - file_name = 'epoch_' + str(epoch) + '__' + file_name + file_name = "epoch_" + str(epoch) + "__" + file_name save_path = os.path.join(save_dir, file_name) torch.save(model.state_dict(), save_path) def main(args): - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") p = args.labels_per_batch k = args.samples_per_label batch_size = p * k @@ -103,9 +102,9 @@ def main(args): criterion = TripletMarginLoss(margin=args.margin) optimizer = Adam(model.parameters(), lr=args.lr) - transform = transforms.Compose([transforms.Lambda(lambda image: image.convert('RGB')), - transforms.Resize((224, 224)), - transforms.ToTensor()]) + transform = transforms.Compose( + [transforms.Lambda(lambda image: image.convert("RGB")), transforms.Resize((224, 224)), transforms.ToTensor()] + ) # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can # be replaced with any classification dataset. @@ -118,48 +117,44 @@ def main(args): # targets attribute with the same format. 
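As the comment above notes, FashionMNIST is only a stand-in: any classification dataset whose labels are available as a flat list of integers can drive the P-K batch construction. The sketch below is not part of this diff; CIFAR-10 is chosen purely for illustration, and PKSampler is assumed to keep the (targets, p, k) signature used in train.py and test.py.

# Sketch only, not part of the diff: swapping CIFAR-10 into the P-K sampling pipeline.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from sampler import PKSampler  # reference module from references/similarity

p, k = 8, 8  # p distinct labels per batch, k samples per label -> batch size is p * k
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

train_dataset = CIFAR10("/tmp/cifar10", train=True, download=True, transform=transform)
targets = list(train_dataset.targets)  # CIFAR-10 already exposes targets as a list of ints

train_loader = DataLoader(
    train_dataset, batch_size=p * k, sampler=PKSampler(targets, p, k), num_workers=4
)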
targets = train_dataset.targets.tolist() - train_loader = DataLoader(train_dataset, batch_size=batch_size, - sampler=PKSampler(targets, p, k), - num_workers=args.workers) - test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, - shuffle=False, - num_workers=args.workers) + train_loader = DataLoader( + train_dataset, batch_size=batch_size, sampler=PKSampler(targets, p, k), num_workers=args.workers + ) + test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.workers) for epoch in range(1, args.epochs + 1): - print('Training...') + print("Training...") train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq) - print('Evaluating...') + print("Evaluating...") evaluate(model, test_loader, device) - print('Saving...') - save(model, epoch, args.save_dir, 'ckpt.pth') + print("Saving...") + save(model, epoch, args.save_dir, "ckpt.pth") def parse_args(): import argparse - parser = argparse.ArgumentParser(description='PyTorch Embedding Learning') - - parser.add_argument('--dataset-dir', default='/tmp/fmnist/', - help='FashionMNIST dataset directory path') - parser.add_argument('-p', '--labels-per-batch', default=8, type=int, - help='Number of unique labels/classes per batch') - parser.add_argument('-k', '--samples-per-label', default=8, type=int, - help='Number of samples per label in a batch') - parser.add_argument('--eval-batch-size', default=512, type=int) - parser.add_argument('--epochs', default=10, type=int, metavar='N', - help='Number of training epochs to run') - parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='Number of data loading workers') - parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate') - parser.add_argument('--margin', default=0.2, type=float, help='Triplet loss margin') - parser.add_argument('--print-freq', default=20, type=int, help='Print frequency') - parser.add_argument('--save-dir', default='.', help='Model save directory') - parser.add_argument('--resume', default='', help='Resume from checkpoint') + + parser = argparse.ArgumentParser(description="PyTorch Embedding Learning") + + parser.add_argument("--dataset-dir", default="/tmp/fmnist/", help="FashionMNIST dataset directory path") + parser.add_argument( + "-p", "--labels-per-batch", default=8, type=int, help="Number of unique labels/classes per batch" + ) + parser.add_argument("-k", "--samples-per-label", default=8, type=int, help="Number of samples per label in a batch") + parser.add_argument("--eval-batch-size", default=512, type=int) + parser.add_argument("--epochs", default=10, type=int, metavar="N", help="Number of training epochs to run") + parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="Number of data loading workers") + parser.add_argument("--lr", default=0.0001, type=float, help="Learning rate") + parser.add_argument("--margin", default=0.2, type=float, help="Triplet loss margin") + parser.add_argument("--print-freq", default=20, type=int, help="Print frequency") + parser.add_argument("--save-dir", default=".", help="Model save directory") + parser.add_argument("--resume", default="", help="Resume from checkpoint") return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 3ee679ad5af..0a3152fed0d 100644 --- a/references/video_classification/presets.py +++ 
b/references/video_classification/presets.py @@ -1,12 +1,18 @@ -import torch +from transforms import ConvertBCHWtoCBHW, ConvertBHWCtoBCHW +import torch from torchvision.transforms import transforms -from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW class VideoClassificationPresetTrain: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), - hflip_prob=0.5): + def __init__( + self, + resize_size, + crop_size, + mean=(0.43216, 0.394666, 0.37645), + std=(0.22803, 0.22145, 0.216989), + hflip_prob=0.5, + ): trans = [ ConvertBHWCtoBCHW(), transforms.ConvertImageDtype(torch.float32), @@ -14,11 +20,7 @@ def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), st ] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - trans.extend([ - transforms.Normalize(mean=mean, std=std), - transforms.RandomCrop(crop_size), - ConvertBCHWtoCBHW() - ]) + trans.extend([transforms.Normalize(mean=mean, std=std), transforms.RandomCrop(crop_size), ConvertBCHWtoCBHW()]) self.transforms = transforms.Compose(trans) def __call__(self, x): @@ -27,14 +29,16 @@ def __call__(self, x): class VideoClassificationPresetEval: def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): - self.transforms = transforms.Compose([ - ConvertBHWCtoBCHW(), - transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), - transforms.Normalize(mean=mean, std=std), - transforms.CenterCrop(crop_size), - ConvertBCHWtoCBHW() - ]) + self.transforms = transforms.Compose( + [ + ConvertBHWCtoBCHW(), + transforms.ConvertImageDtype(torch.float32), + transforms.Resize(resize_size), + transforms.Normalize(mean=mean, std=std), + transforms.CenterCrop(crop_size), + ConvertBCHWtoCBHW(), + ] + ) def __call__(self, x): return self.transforms(x) diff --git a/references/video_classification/scheduler.py b/references/video_classification/scheduler.py index f0f862d41ad..4ec7060595f 100644 --- a/references/video_classification/scheduler.py +++ b/references/video_classification/scheduler.py @@ -1,6 +1,7 @@ -import torch from bisect import bisect_right +import torch + class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): def __init__( @@ -20,10 +21,7 @@ def __init__( ) if warmup_method not in ("constant", "linear"): - raise ValueError( - "Only 'constant' or 'linear' warmup_method accepted" - "got {}".format(warmup_method) - ) + raise ValueError("Only 'constant' or 'linear' warmup_method accepted" "got {}".format(warmup_method)) self.milestones = milestones self.gamma = gamma self.warmup_factor = warmup_factor @@ -40,8 +38,6 @@ def get_lr(self): alpha = float(self.last_epoch) / self.warmup_iters warmup_factor = self.warmup_factor * (1 - alpha) + alpha return [ - base_lr * - warmup_factor * - self.gamma ** bisect_right(self.milestones, self.last_epoch) + base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) for base_lr in self.base_lrs ] diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 11ac2d5378d..61553f5ebcb 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -1,19 +1,19 @@ import datetime import os import time -import torch -import torch.utils.data -from torch.utils.data.dataloader import default_collate -from torch import nn -import torchvision -import torchvision.datasets.video_utils -from torchvision.datasets.samplers import 
DistributedSampler, UniformClipSampler, RandomClipSampler import presets import utils - from scheduler import WarmupMultiStepLR +import torch +import torch.utils.data +import torchvision +import torchvision.datasets.video_utils +from torch import nn +from torch.utils.data.dataloader import default_collate +from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler + try: from apex import amp except ImportError: @@ -23,10 +23,10 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) + metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}")) - header = 'Epoch: [{}]'.format(epoch) + header = "Epoch: [{}]".format(epoch) for video, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() video, target = video.to(device), target.to(device) @@ -44,16 +44,16 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = video.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time)) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time)) lr_scheduler.step() def evaluate(model, criterion, data_loader, device): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" with torch.no_grad(): for video, target in metric_logger.log_every(data_loader, 100, header): video = video.to(device, non_blocking=True) @@ -66,18 +66,22 @@ def evaluate(model, criterion, data_loader, device): # could have been padded in distributed setup batch_size = video.shape[0] metric_logger.update(loss=loss.item()) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() - print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}' - .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + print( + " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format( + top1=metric_logger.acc1, top5=metric_logger.acc5 + ) + ) return metric_logger.acc1.global_avg def _get_cache_path(filepath): import hashlib + h = hashlib.sha1(filepath.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) @@ -92,8 +96,10 @@ def collate_fn(batch): def main(args): if args.apex and amp is None: - raise RuntimeError("Failed to import apex. 
Please install apex from https://www.github.com/nvidia/apex " - "to enable mixed-precision training.") + raise RuntimeError( + "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training." + ) if args.output_dir: utils.mkdir(args.output_dir) @@ -123,15 +129,17 @@ def main(args): dataset.transform = transform_train else: if args.distributed: - print("It is recommended to pre-compute the dataset cache " - "on a single-gpu first, as it will be faster") + print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") dataset = torchvision.datasets.Kinetics400( traindir, frames_per_clip=args.clip_len, step_between_clips=1, transform=transform_train, frame_rate=15, - extensions=('avi', 'mp4', ) + extensions=( + "avi", + "mp4", + ), ) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) @@ -151,15 +159,17 @@ def main(args): dataset_test.transform = transform_test else: if args.distributed: - print("It is recommended to pre-compute the dataset cache " - "on a single-gpu first, as it will be faster") + print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") dataset_test = torchvision.datasets.Kinetics400( valdir, frames_per_clip=args.clip_len, step_between_clips=1, transform=transform_test, frame_rate=15, - extensions=('avi', 'mp4',) + extensions=( + "avi", + "mp4", + ), ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) @@ -174,14 +184,22 @@ def main(args): test_sampler = DistributedSampler(test_sampler) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, - pin_memory=True, collate_fn=collate_fn) + dataset, + batch_size=args.batch_size, + sampler=train_sampler, + num_workers=args.workers, + pin_memory=True, + collate_fn=collate_fn, + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, - sampler=test_sampler, num_workers=args.workers, - pin_memory=True, collate_fn=collate_fn) + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, + collate_fn=collate_fn, + ) print("Creating model") model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) @@ -192,21 +210,18 @@ def main(args): criterion = nn.CrossEntropyLoss() lr = args.lr * args.world_size - optimizer = torch.optim.SGD( - model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.apex: - model, optimizer = amp.initialize(model, optimizer, - opt_level=args.apex_opt_level - ) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) # convert scheduler to be per iteration, not per epoch, for warmup that lasts # between different epochs warmup_iters = args.lr_warmup_epochs * len(data_loader) lr_milestones = [len(data_loader) * m for m in args.lr_milestones] lr_scheduler = WarmupMultiStepLR( - optimizer, milestones=lr_milestones, gamma=args.lr_gamma, - warmup_iters=warmup_iters, warmup_factor=1e-5) + optimizer, milestones=lr_milestones, gamma=args.lr_gamma, warmup_iters=warmup_iters, warmup_factor=1e-5 + ) model_without_ddp = model if args.distributed: @@ -214,11 +229,11 @@ def main(args): model_without_ddp = model.module if args.resume: - checkpoint = 
torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -229,60 +244,63 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, - device, epoch, args.print_freq, args.apex) + train_one_epoch( + model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex + ) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def parse_args(): import argparse - parser = argparse.ArgumentParser(description='PyTorch Video Classification Training') - - parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') - parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') - parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') - parser.add_argument('--model', default='r2plus1d_18', help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('--clip-len', default=16, type=int, metavar='N', - help='number of frames per clip') - parser.add_argument('--clips-per-video', default=5, type=int, metavar='N', - help='maximum number of clips per video to consider') - parser.add_argument('-b', '--batch-size', default=24, type=int) - parser.add_argument('--epochs', default=45, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', - help='number of data loading workers (default: 10)') - parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones') - parser.add_argument('--lr-gamma', 
default=0.1, type=float, help='decrease lr by a factor of lr-gamma') - parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Video Classification Training") + + parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", help="dataset") + parser.add_argument("--train-dir", default="train_avi-480p", help="name of train dir") + parser.add_argument("--val-dir", default="val_avi-480p", help="name of val dir") + parser.add_argument("--model", default="r2plus1d_18", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip") + parser.add_argument( + "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider" + ) + parser.add_argument("-b", "--batch-size", default=24, type=int) + parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)" + ) + parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="number of warmup epochs") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -309,18 +327,19 @@ def parse_args(): ) # Mixed precision training parameters - parser.add_argument('--apex', action='store_true', - help='Use apex for mixed precision training') - parser.add_argument('--apex-opt-level', default='O1', type=str, - help='For apex mixed precision training' - 'O0 for FP32 training, O1 for mixed precision training.' - 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' - ) + parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") + parser.add_argument( + "--apex-opt-level", + default="O1", + type=str, + help="For apex mixed precision training" + "O0 for FP32 training, O1 for mixed precision training." 
+ "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", + ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") args = parser.parse_args() diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 27f6c75450a..a0ce691bae7 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -3,16 +3,14 @@ class ConvertBHWCtoBCHW(nn.Module): - """Convert tensor from (B, H, W, C) to (B, C, H, W) - """ + """Convert tensor from (B, H, W, C) to (B, C, H, W)""" def forward(self, vid: torch.Tensor) -> torch.Tensor: return vid.permute(0, 3, 1, 2) class ConvertBCHWtoCBHW(nn.Module): - """Convert tensor from (B, C, H, W) to (C, B, H, W) - """ + """Convert tensor from (B, C, H, W) to (C, B, H, W)""" def forward(self, vid: torch.Tensor) -> torch.Tensor: return vid.permute(1, 0, 2, 3) diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py index 3573b84d780..956c4f85239 100644 --- a/references/video_classification/utils.py +++ b/references/video_classification/utils.py @@ -1,12 +1,12 @@ -from collections import defaultdict, deque import datetime +import errno +import os import time +from collections import defaultdict, deque + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class MetricLogger(object): @@ -87,15 +84,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -108,31 +102,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + 
str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -142,21 +133,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) def accuracy(output, target, topk=(1,)): @@ -189,10 +187,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -229,26 +228,26 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + 
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) diff --git a/setup.cfg b/setup.cfg index fd3b74c47de..77981a6bcf0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,39 @@ max-line-length = 120 [flake8] max-line-length = 120 -ignore = F401,E402,F403,W503,W504,F821 +ignore = E203, E402, W503, W504, F821 +per-file-ignores = + __init__.py: F401, F403, F405 + ./hubconf.py: F401 + torchvision/models/mobilenet.py: F401, F403 + torchvision/models/quantization/mobilenet.py: F401, F403 + test/smoke_test.py: F401 exclude = venv [pydocstyle] select = D417 # Missing argument descriptions in the docstring + +[isort] +# See link below for available options +# https://pycqa.github.io/isort/docs/configuration/options.html + +profile = black +line_length = 120 + +skip_gitignore = true +order_by_type = true +combine_as_imports = true +combine_star = true +float_to_top = true + +known_first_party = + torch, + torchvision +known_local_folder = + _utils_internal + common_utils + dataset_utils + + +skip = + gallery \ No newline at end of file diff --git a/setup.py b/setup.py index 195fdb2e7be..dcde2128314 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,22 @@ -import os -import io -import sys -from setuptools import setup, find_packages -from pkg_resources import parse_version, get_distribution, DistributionNotFound -import subprocess import distutils.command.clean import distutils.spawn -from distutils.version import StrictVersion import glob +import io +import os import shutil +import subprocess +import sys +from distutils.version import StrictVersion + +from pkg_resources import DistributionNotFound, get_distribution, parse_version +from setuptools import find_packages, setup import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension, CUDAExtension def read(*names, **kwargs): - with io.open( - os.path.join(os.path.dirname(__file__), *names), - encoding=kwargs.get("encoding", "utf8") - ) as fp: + with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp: return fp.read() @@ -31,61 +29,61 @@ def get_dist(pkgname): cwd = os.path.dirname(os.path.abspath(__file__)) -version_txt = os.path.join(cwd, 'version.txt') -with open(version_txt, 'r') as f: +version_txt = os.path.join(cwd, "version.txt") +with open(version_txt, "r") as f: version = f.readline().strip() -sha = 'Unknown' -package_name = 'torchvision' +sha = "Unknown" +package_name = "torchvision" try: - sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() + sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip() except Exception: pass -if os.getenv('BUILD_VERSION'): - version = os.getenv('BUILD_VERSION') -elif sha != 'Unknown': - version += '+' + sha[:7] +if os.getenv("BUILD_VERSION"): + version = os.getenv("BUILD_VERSION") +elif sha != "Unknown": + version += "+" + sha[:7] def write_version_file(): - version_path = os.path.join(cwd, 'torchvision', 'version.py') - with open(version_path, 'w') as f: - f.write("__version__ = '{}'\n".format(version)) - f.write("git_version = {}\n".format(repr(sha))) - f.write("from torchvision.extension import _check_cuda_version\n") + version_path = os.path.join(cwd, "torchvision", "version.py") + with open(version_path, "w") as f: + f.write('__version__ = "{}"\n'.format(version)) + 
f.write('git_version = "{}"\n'.format(sha)) + f.write("from torchvision.extension import _check_cuda_version\n\n") f.write("if _check_cuda_version() > 0:\n") f.write(" cuda = _check_cuda_version()\n") -pytorch_dep = 'torch' -if os.getenv('PYTORCH_VERSION'): - pytorch_dep += "==" + os.getenv('PYTORCH_VERSION') +pytorch_dep = "torch" +if os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") requirements = [ - 'numpy', + "numpy", pytorch_dep, ] # Excluding 8.3.0 because of https://github.com/pytorch/vision/issues/4146 -pillow_ver = ' >= 5.3.0, !=8.3.0' -pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' +pillow_ver = " >= 5.3.0, !=8.3.0" +pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow" requirements.append(pillow_req + pillow_ver) def find_library(name, vision_include): this_dir = os.path.dirname(os.path.abspath(__file__)) - build_prefix = os.environ.get('BUILD_PREFIX', None) + build_prefix = os.environ.get("BUILD_PREFIX", None) is_conda_build = build_prefix is not None library_found = False conda_installed = False lib_folder = None include_folder = None - library_header = '{0}.h'.format(name) + library_header = "{0}.h".format(name) # Lookup in TORCHVISION_INCLUDE or in the package file - package_path = [os.path.join(this_dir, 'torchvision')] + package_path = [os.path.join(this_dir, "torchvision")] for folder in vision_include + package_path: candidate_path = os.path.join(folder, library_header) library_found = os.path.exists(candidate_path) @@ -93,65 +91,64 @@ def find_library(name, vision_include): break if not library_found: - print('Running build on conda-build: {0}'.format(is_conda_build)) + print("Running build on conda-build: {0}".format(is_conda_build)) if is_conda_build: # Add conda headers/libraries - if os.name == 'nt': - build_prefix = os.path.join(build_prefix, 'Library') - include_folder = os.path.join(build_prefix, 'include') - lib_folder = os.path.join(build_prefix, 'lib') - library_header_path = os.path.join( - include_folder, library_header) + if os.name == "nt": + build_prefix = os.path.join(build_prefix, "Library") + include_folder = os.path.join(build_prefix, "include") + lib_folder = os.path.join(build_prefix, "lib") + library_header_path = os.path.join(include_folder, library_header) library_found = os.path.isfile(library_header_path) conda_installed = library_found else: # Check if using Anaconda to produce wheels - conda = distutils.spawn.find_executable('conda') + conda = distutils.spawn.find_executable("conda") is_conda = conda is not None - print('Running build on conda: {0}'.format(is_conda)) + print("Running build on conda: {0}".format(is_conda)) if is_conda: python_executable = sys.executable py_folder = os.path.dirname(python_executable) - if os.name == 'nt': - env_path = os.path.join(py_folder, 'Library') + if os.name == "nt": + env_path = os.path.join(py_folder, "Library") else: env_path = os.path.dirname(py_folder) - lib_folder = os.path.join(env_path, 'lib') - include_folder = os.path.join(env_path, 'include') - library_header_path = os.path.join( - include_folder, library_header) + lib_folder = os.path.join(env_path, "lib") + include_folder = os.path.join(env_path, "include") + library_header_path = os.path.join(include_folder, library_header) library_found = os.path.isfile(library_header_path) conda_installed = library_found if not library_found: - if sys.platform == 'linux': - library_found = os.path.exists('/usr/include/{0}'.format( - library_header)) - 
library_found = library_found or os.path.exists( - '/usr/local/include/{0}'.format(library_header)) + if sys.platform == "linux": + library_found = os.path.exists("/usr/include/{0}".format(library_header)) + library_found = library_found or os.path.exists("/usr/local/include/{0}".format(library_header)) return library_found, conda_installed, include_folder, lib_folder def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) - extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') + extensions_dir = os.path.join(this_dir, "torchvision", "csrc") - main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops', - '*.cpp')) + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob( + os.path.join(extensions_dir, "ops", "*.cpp") + ) source_cpu = ( - glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + - glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) + - glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp')) + glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) is_rocm_pytorch = False - if torch.__version__ >= '1.5': + if torch.__version__ >= "1.5": from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False if is_rocm_pytorch: from torch.utils.hipify import hipify_python + hipify_python.hipify( project_directory=this_dir, output_directory=this_dir, @@ -159,25 +156,25 @@ def get_extensions(): show_detailed=True, is_pytorch_extension=True, ) - source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'hip', '*.hip')) + source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip")) # Copy over additional files for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"): shutil.copy(file, "torchvision/csrc/ops/hip") else: - source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu')) + source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu")) - source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp')) + source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp")) sources = main_file + source_cpu extension = CppExtension - compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1' + compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1" if compile_cpp_tests: - test_dir = os.path.join(this_dir, 'test') - models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') - test_file = glob.glob(os.path.join(test_dir, '*.cpp')) - source_models = glob.glob(os.path.join(models_dir, '*.cpp')) + test_dir = os.path.join(this_dir, "test") + models_dir = os.path.join(this_dir, "torchvision", "csrc", "models") + test_file = glob.glob(os.path.join(test_dir, "*.cpp")) + source_models = glob.glob(os.path.join(models_dir, "*.cpp")) test_file = [os.path.join(test_dir, s) for s in test_file] source_models = [os.path.join(models_dir, s) for s in source_models] @@ -186,39 +183,38 @@ def get_extensions(): define_macros = [] - extra_compile_args = {'cxx': []} - if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) \ - or os.getenv('FORCE_CUDA', '0') == '1': + extra_compile_args = {"cxx": []} + if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv( + 
"FORCE_CUDA", "0" + ) == "1": extension = CUDAExtension sources += source_cuda if not is_rocm_pytorch: - define_macros += [('WITH_CUDA', None)] - nvcc_flags = os.getenv('NVCC_FLAGS', '') - if nvcc_flags == '': + define_macros += [("WITH_CUDA", None)] + nvcc_flags = os.getenv("NVCC_FLAGS", "") + if nvcc_flags == "": nvcc_flags = [] else: - nvcc_flags = nvcc_flags.split(' ') + nvcc_flags = nvcc_flags.split(" ") else: - define_macros += [('WITH_HIP', None)] + define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags - if sys.platform == 'win32': - define_macros += [('torchvision_EXPORTS', None)] + if sys.platform == "win32": + define_macros += [("torchvision_EXPORTS", None)] - extra_compile_args['cxx'].append('/MP') + extra_compile_args["cxx"].append("/MP") - debug_mode = os.getenv('DEBUG', '0') == '1' + debug_mode = os.getenv("DEBUG", "0") == "1" if debug_mode: print("Compile in debug mode") - extra_compile_args['cxx'].append("-g") - extra_compile_args['cxx'].append("-O0") + extra_compile_args["cxx"].append("-g") + extra_compile_args["cxx"].append("-O0") if "nvcc" in extra_compile_args: # we have to remove "-OX" and "-g" flag if exists and append nvcc_flags = extra_compile_args["nvcc"] - extra_compile_args["nvcc"] = [ - f for f in nvcc_flags if not ("-O" in f or "-g" in f) - ] + extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)] extra_compile_args["nvcc"].append("-O0") extra_compile_args["nvcc"].append("-g") @@ -228,7 +224,7 @@ def get_extensions(): ext_modules = [ extension( - 'torchvision._C', + "torchvision._C", sorted(sources), include_dirs=include_dirs, define_macros=define_macros, @@ -238,7 +234,7 @@ def get_extensions(): if compile_cpp_tests: ext_modules.append( extension( - 'torchvision._C_tests', + "torchvision._C_tests", tests, include_dirs=tests_include_dirs, define_macros=define_macros, @@ -247,12 +243,10 @@ def get_extensions(): ) # ------------------- Torchvision extra extensions ------------------------ - vision_include = os.environ.get('TORCHVISION_INCLUDE', None) - vision_library = os.environ.get('TORCHVISION_LIBRARY', None) - vision_include = (vision_include.split(os.pathsep) - if vision_include is not None else []) - vision_library = (vision_library.split(os.pathsep) - if vision_library is not None else []) + vision_include = os.environ.get("TORCHVISION_INCLUDE", None) + vision_library = os.environ.get("TORCHVISION_LIBRARY", None) + vision_include = vision_include.split(os.pathsep) if vision_include is not None else [] + vision_library = vision_library.split(os.pathsep) if vision_library is not None else [] include_dirs += vision_include library_dirs = vision_library @@ -263,56 +257,49 @@ def get_extensions(): image_link_flags = [] # Locating libPNG - libpng = distutils.spawn.find_executable('libpng-config') - pngfix = distutils.spawn.find_executable('pngfix') + libpng = distutils.spawn.find_executable("libpng-config") + pngfix = distutils.spawn.find_executable("pngfix") png_found = libpng is not None or pngfix is not None - print('PNG found: {0}'.format(png_found)) + print("PNG found: {0}".format(png_found)) if png_found: if libpng is not None: # Linux / Mac - png_version = subprocess.run([libpng, '--version'], - stdout=subprocess.PIPE) - png_version = png_version.stdout.strip().decode('utf-8') - print('libpng version: {0}'.format(png_version)) + png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE) + png_version = png_version.stdout.strip().decode("utf-8") + print("libpng version: 
{0}".format(png_version)) png_version = parse_version(png_version) if png_version >= parse_version("1.6.0"): - print('Building torchvision with PNG image support') - png_lib = subprocess.run([libpng, '--libdir'], - stdout=subprocess.PIPE) - png_lib = png_lib.stdout.strip().decode('utf-8') - if 'disabled' not in png_lib: + print("Building torchvision with PNG image support") + png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE) + png_lib = png_lib.stdout.strip().decode("utf-8") + if "disabled" not in png_lib: image_library += [png_lib] - png_include = subprocess.run([libpng, '--I_opts'], - stdout=subprocess.PIPE) - png_include = png_include.stdout.strip().decode('utf-8') - _, png_include = png_include.split('-I') - print('libpng include path: {0}'.format(png_include)) + png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE) + png_include = png_include.stdout.strip().decode("utf-8") + _, png_include = png_include.split("-I") + print("libpng include path: {0}".format(png_include)) image_include += [png_include] - image_link_flags.append('png') + image_link_flags.append("png") else: - print('libpng installed version is less than 1.6.0, ' - 'disabling PNG support') + print("libpng installed version is less than 1.6.0, " "disabling PNG support") png_found = False else: # Windows - png_lib = os.path.join( - os.path.dirname(os.path.dirname(pngfix)), 'lib') - png_include = os.path.join(os.path.dirname( - os.path.dirname(pngfix)), 'include', 'libpng16') + png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib") + png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16") image_library += [png_lib] image_include += [png_include] - image_link_flags.append('libpng') + image_link_flags.append("libpng") # Locating libjpeg - (jpeg_found, jpeg_conda, - jpeg_include, jpeg_lib) = find_library('jpeglib', vision_include) + (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include) - print('JPEG found: {0}'.format(jpeg_found)) - image_macros += [('PNG_FOUND', str(int(png_found)))] - image_macros += [('JPEG_FOUND', str(int(jpeg_found)))] + print("JPEG found: {0}".format(jpeg_found)) + image_macros += [("PNG_FOUND", str(int(png_found)))] + image_macros += [("JPEG_FOUND", str(int(jpeg_found)))] if jpeg_found: - print('Building torchvision with JPEG image support') - image_link_flags.append('jpeg') + print("Building torchvision with JPEG image support") + image_link_flags.append("jpeg") if jpeg_conda: image_library += [jpeg_lib] image_include += [jpeg_include] @@ -320,78 +307,70 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI nvjpeg_found = ( - extension is CUDAExtension and - CUDA_HOME is not None and - os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + extension is CUDAExtension + and CUDA_HOME is not None + and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h")) ) - print('NVJPEG found: {0}'.format(nvjpeg_found)) - image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] + print("NVJPEG found: {0}".format(nvjpeg_found)) + image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))] if nvjpeg_found: - print('Building torchvision with NVJPEG image support') - image_link_flags.append('nvjpeg') - - image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) - + 
glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) + print("Building torchvision with NVJPEG image support") + image_link_flags.append("nvjpeg") + + image_path = os.path.join(extensions_dir, "io", "image") + image_src = ( + glob.glob(os.path.join(image_path, "*.cpp")) + + glob.glob(os.path.join(image_path, "cpu", "*.cpp")) + + glob.glob(os.path.join(image_path, "cuda", "*.cpp")) + ) if png_found or jpeg_found: - ext_modules.append(extension( - 'torchvision.image', - image_src, - include_dirs=image_include + include_dirs + [image_path], - library_dirs=image_library + library_dirs, - define_macros=image_macros, - libraries=image_link_flags, - extra_compile_args=extra_compile_args - )) - - ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') + ext_modules.append( + extension( + "torchvision.image", + image_src, + include_dirs=image_include + include_dirs + [image_path], + library_dirs=image_library + library_dirs, + define_macros=image_macros, + libraries=image_link_flags, + extra_compile_args=extra_compile_args, + ) + ) + + ffmpeg_exe = distutils.spawn.find_executable("ffmpeg") has_ffmpeg = ffmpeg_exe is not None if has_ffmpeg: try: ffmpeg_version = subprocess.run( - 'ffmpeg -version | head -n1', shell=True, - stdout=subprocess.PIPE).stdout.decode('utf-8') - ffmpeg_version = ffmpeg_version.split('version')[-1].split()[0] - if StrictVersion(ffmpeg_version) >= StrictVersion('4.3'): - print(f'ffmpeg {ffmpeg_version} not supported yet, please use ffmpeg 4.2.') + "ffmpeg -version | head -n1", shell=True, stdout=subprocess.PIPE + ).stdout.decode("utf-8") + ffmpeg_version = ffmpeg_version.split("version")[-1].split()[0] + if StrictVersion(ffmpeg_version) >= StrictVersion("4.3"): + print(f"ffmpeg {ffmpeg_version} not supported yet, please use ffmpeg 4.2.") has_ffmpeg = False except (IndexError, ValueError): - print('Error fetching ffmpeg version, ignoring ffmpeg.') + print("Error fetching ffmpeg version, ignoring ffmpeg.") has_ffmpeg = False print("FFmpeg found: {}".format(has_ffmpeg)) if has_ffmpeg: - ffmpeg_libraries = { - 'libavcodec', - 'libavformat', - 'libavutil', - 'libswresample', - 'libswscale' - } + ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"} ffmpeg_bin = os.path.dirname(ffmpeg_exe) ffmpeg_root = os.path.dirname(ffmpeg_bin) - ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include') - ffmpeg_library_dir = os.path.join(ffmpeg_root, 'lib') + ffmpeg_include_dir = os.path.join(ffmpeg_root, "include") + ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib") - gcc = distutils.spawn.find_executable('gcc') - platform_tag = subprocess.run( - [gcc, '-print-multiarch'], stdout=subprocess.PIPE) - platform_tag = platform_tag.stdout.strip().decode('utf-8') + gcc = distutils.spawn.find_executable("gcc") + platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE) + platform_tag = platform_tag.stdout.strip().decode("utf-8") if platform_tag: # Most probably a Debian-based distribution - ffmpeg_include_dir = [ - ffmpeg_include_dir, - os.path.join(ffmpeg_include_dir, platform_tag) - ] - ffmpeg_library_dir = [ - ffmpeg_library_dir, - os.path.join(ffmpeg_library_dir, platform_tag) - ] + ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)] + ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)] else: ffmpeg_include_dir = [ffmpeg_include_dir] ffmpeg_library_dir = [ffmpeg_library_dir] @@ -400,11 +379,11 @@ def get_extensions(): for library in ffmpeg_libraries: 
library_found = False for search_path in ffmpeg_include_dir + include_dirs: - full_path = os.path.join(search_path, library, '*.h') + full_path = os.path.join(search_path, library, "*.h") library_found |= len(glob.glob(full_path)) > 0 if not library_found: - print(f'{library} header files were not found, disabling ffmpeg support') + print(f"{library} header files were not found, disabling ffmpeg support") has_ffmpeg = False if has_ffmpeg: @@ -412,22 +391,21 @@ def get_extensions(): print("ffmpeg library_dir: {}".format(ffmpeg_library_dir)) # TorchVision base decoder + video reader - video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video_reader') + video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader") video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp")) - base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'decoder') - base_decoder_src = glob.glob( - os.path.join(base_decoder_src_dir, "*.cpp")) + base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder") + base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp")) # Torchvision video API - videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video') + videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video") videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp")) # exclude tests - base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x] + base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x] combined_src = video_reader_src + base_decoder_src + videoapi_src ext_modules.append( CppExtension( - 'torchvision.video_reader', + "torchvision.video_reader", combined_src, include_dirs=[ base_decoder_src_dir, @@ -435,18 +413,18 @@ def get_extensions(): videoapi_src_dir, extensions_dir, *ffmpeg_include_dir, - *include_dirs + *include_dirs, ], library_dirs=ffmpeg_library_dir + library_dirs, libraries=[ - 'avcodec', - 'avformat', - 'avutil', - 'swresample', - 'swscale', + "avcodec", + "avformat", + "avutil", + "swresample", + "swscale", ], - extra_compile_args=["-std=c++14"] if os.name != 'nt' else ['/std:c++14', '/MP'], - extra_link_args=["-std=c++14" if os.name != 'nt' else '/std:c++14'], + extra_compile_args=["-std=c++14"] if os.name != "nt" else ["/std:c++14", "/MP"], + extra_link_args=["-std=c++14" if os.name != "nt" else "/std:c++14"], ) ) @@ -455,9 +433,9 @@ def get_extensions(): class clean(distutils.command.clean.clean): def run(self): - with open('.gitignore', 'r') as f: + with open(".gitignore", "r") as f: ignores = f.read() - for wildcard in filter(None, ignores.split('\n')): + for wildcard in filter(None, ignores.split("\n")): for filename in glob.glob(wildcard): try: os.remove(filename) @@ -473,25 +451,22 @@ def run(self): write_version_file() - with open('README.rst') as f: + with open("README.rst") as f: readme = f.read() setup( # Metadata name=package_name, version=version, - author='PyTorch Core Team', - author_email='soumith@pytorch.org', - url='https://github.com/pytorch/vision', - description='image and video datasets and models for torch deep learning', + author="PyTorch Core Team", + author_email="soumith@pytorch.org", + url="https://github.com/pytorch/vision", + description="image and video datasets and models for torch deep learning", long_description=readme, - license='BSD', - + license="BSD", # Package info - packages=find_packages(exclude=('test',)), - package_data={ - package_name: ['*.dll', 
'*.dylib', '*.so'] - }, + packages=find_packages(exclude=("test",)), + package_data={package_name: ["*.dll", "*.dylib", "*.so"]}, zip_safe=False, install_requires=requirements, extras_require={ @@ -499,7 +474,7 @@ def run(self): }, ext_modules=get_extensions(), cmdclass={ - 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True), - 'clean': clean, - } + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), + "clean": clean, + }, ) diff --git a/test/common_utils.py b/test/common_utils.py index 1da5226f425..75295b6c446 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -1,30 +1,27 @@ +import contextlib +import functools import os +import random import shutil +import sys import tempfile -import contextlib import unittest -import argparse -import sys -import torch -import __main__ -import random -import inspect -import functools - -from numbers import Number -from torch._six import string_classes from collections import OrderedDict +from numbers import Number import numpy as np from PIL import Image +import torch +from torch._six import string_classes + IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9 PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) -IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' +IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true" IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" -CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' +CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda." 
@@ -88,27 +85,26 @@ def is_iterable(obj): class TestCase(unittest.TestCase): precision = 1e-5 - def assertEqual(self, x, y, prec=None, message='', allow_inf=False): + def assertEqual(self, x, y, prec=None, message="", allow_inf=False): """ This is copied from pytorch/test/common_utils.py's TestCase.assertEqual """ - if isinstance(prec, str) and message == '': + if isinstance(prec, str) and message == "": message = prec prec = None if prec is None: prec = self.precision if isinstance(x, torch.Tensor) and isinstance(y, Number): - self.assertEqual(x.item(), y, prec=prec, message=message, - allow_inf=allow_inf) + self.assertEqual(x.item(), y, prec=prec, message=message, allow_inf=allow_inf) elif isinstance(y, torch.Tensor) and isinstance(x, Number): - self.assertEqual(x, y.item(), prec=prec, message=message, - allow_inf=allow_inf) + self.assertEqual(x, y.item(), prec=prec, message=message, allow_inf=allow_inf) elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + def assertTensorsEqual(a, b): super(TestCase, self).assertEqual(a.size(), b.size(), message) if a.numel() > 0: - if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)): + if a.device.type == "cpu" and (a.dtype == torch.float16 or a.dtype == torch.bfloat16): # CPU half and bfloat16 tensors don't have the methods we need below a = a.to(torch.float32) b = b.to(a) @@ -140,6 +136,7 @@ def assertTensorsEqual(a, b): max_err = diff.max() tolerance = prec + prec * abs(a.max()) self.assertLessEqual(max_err, tolerance, message) + super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message) super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message) if x.is_sparse: @@ -148,26 +145,36 @@ def assertTensorsEqual(a, b): assertTensorsEqual(x._indices(), y._indices()) assertTensorsEqual(x._values(), y._values()) elif x.is_quantized and y.is_quantized: - self.assertEqual(x.qscheme(), y.qscheme(), prec=prec, - message=message, allow_inf=allow_inf) + self.assertEqual(x.qscheme(), y.qscheme(), prec=prec, message=message, allow_inf=allow_inf) if x.qscheme() == torch.per_tensor_affine: - self.assertEqual(x.q_scale(), y.q_scale(), prec=prec, - message=message, allow_inf=allow_inf) - self.assertEqual(x.q_zero_point(), y.q_zero_point(), - prec=prec, message=message, - allow_inf=allow_inf) + self.assertEqual(x.q_scale(), y.q_scale(), prec=prec, message=message, allow_inf=allow_inf) + self.assertEqual( + x.q_zero_point(), y.q_zero_point(), prec=prec, message=message, allow_inf=allow_inf + ) elif x.qscheme() == torch.per_channel_affine: - self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec, - message=message, allow_inf=allow_inf) - self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(), - prec=prec, message=message, - allow_inf=allow_inf) - self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(), - prec=prec, message=message) + self.assertEqual( + x.q_per_channel_scales(), + y.q_per_channel_scales(), + prec=prec, + message=message, + allow_inf=allow_inf, + ) + self.assertEqual( + x.q_per_channel_zero_points(), + y.q_per_channel_zero_points(), + prec=prec, + message=message, + allow_inf=allow_inf, + ) + self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(), prec=prec, message=message) self.assertEqual(x.dtype, y.dtype) - self.assertEqual(x.int_repr().to(torch.int32), - y.int_repr().to(torch.int32), prec=prec, - message=message, allow_inf=allow_inf) + self.assertEqual( + x.int_repr().to(torch.int32), + 
y.int_repr().to(torch.int32), + prec=prec, + message=message, + allow_inf=allow_inf, + ) else: assertTensorsEqual(x, y) elif isinstance(x, string_classes) and isinstance(y, string_classes): @@ -176,21 +183,17 @@ def assertTensorsEqual(a, b): super(TestCase, self).assertEqual(x, y, message) elif isinstance(x, dict) and isinstance(y, dict): if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): - self.assertEqual(x.items(), y.items(), prec=prec, - message=message, allow_inf=allow_inf) + self.assertEqual(x.items(), y.items(), prec=prec, message=message, allow_inf=allow_inf) else: - self.assertEqual(set(x.keys()), set(y.keys()), prec=prec, - message=message, allow_inf=allow_inf) + self.assertEqual(set(x.keys()), set(y.keys()), prec=prec, message=message, allow_inf=allow_inf) key_list = list(x.keys()) - self.assertEqual([x[k] for k in key_list], - [y[k] for k in key_list], - prec=prec, message=message, - allow_inf=allow_inf) + self.assertEqual( + [x[k] for k in key_list], [y[k] for k in key_list], prec=prec, message=message, allow_inf=allow_inf + ) elif is_iterable(x) and is_iterable(y): super(TestCase, self).assertEqual(len(x), len(y), message) for x_, y_ in zip(x, y): - self.assertEqual(x_, y_, prec=prec, message=message, - allow_inf=allow_inf) + self.assertEqual(x_, y_, prec=prec, message=message, allow_inf=allow_inf) elif isinstance(x, bool) and isinstance(y, bool): super(TestCase, self).assertEqual(x, y, message) elif isinstance(x, Number) and isinstance(y, Number): @@ -219,7 +222,7 @@ def freeze_rng_state(): def cycle_over(objs): for idx, obj1 in enumerate(objs): - for obj2 in objs[:idx] + objs[idx + 1:]: + for obj2 in objs[:idx] + objs[idx + 1 :]: yield obj1, obj2 @@ -241,11 +244,13 @@ def disable_console_output(): def cpu_and_gpu(): import pytest # noqa - return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda)) + + return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) def needs_cuda(test_func): import pytest # noqa + return pytest.mark.needs_cuda(test_func) @@ -258,12 +263,7 @@ def _create_data(height=3, width=3, channels=3, device="cpu"): def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"): # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture - batch_tensor = torch.randint( - 0, 256, - (num_samples, channels, height, width), - dtype=torch.uint8, - device=device - ) + batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device) return batch_tensor @@ -280,8 +280,9 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): assert_equal(tensor.cpu(), pil_tensor, msg=msg) -def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", - allowed_percentage_diff=None): +def _assert_approx_equal_tensor_to_pil( + tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None +): # TODO: we could just merge this into _assert_equal_tensor_to_pil np_pil_image = np.array(pil_image) if np_pil_image.ndim == 2: diff --git a/test/conftest.py b/test/conftest.py index 3cffeeac88f..23e39f93083 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,16 +1,14 @@ -from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG -import torch import pytest +import torch + +from common_utils import CIRCLECI_GPU_NO_CUDA_MSG, CUDA_NOT_AVAILABLE_MSG, IN_CIRCLE_CI, IN_FBCODE, IN_RE_WORKER + def pytest_configure(config): # register an additional marker 
(see pytest_collection_modifyitems) - config.addinivalue_line( - "markers", "needs_cuda: mark for tests that rely on a CUDA device" - ) - config.addinivalue_line( - "markers", "dont_collect: mark for tests that should not be collected" - ) + config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device") + config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected") def pytest_collection_modifyitems(items): @@ -32,7 +30,7 @@ def pytest_collection_modifyitems(items): # @pytest.mark.parametrize('device', cpu_and_gpu()) # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark, # and the ones with device == 'cpu' won't have the mark. - needs_cuda = item.get_closest_marker('needs_cuda') is not None + needs_cuda = item.get_closest_marker("needs_cuda") is not None if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU @@ -57,7 +55,7 @@ def pytest_collection_modifyitems(items): # to run the CPU-only tests. item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG)) - if item.get_closest_marker('dont_collect') is not None: + if item.get_closest_marker("dont_collect") is not None: # currently, this is only used for some tests we're sure we dont want to run on fbcode continue diff --git a/test/datasets_utils.py b/test/datasets_utils.py index d7853b46314..314ba7039ad 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -14,12 +14,12 @@ import PIL import PIL.Image + import torch import torchvision.datasets import torchvision.io -from common_utils import get_tmp_dir, disable_console_output - +from common_utils import disable_console_output, get_tmp_dir __all__ = [ "UsageError", @@ -418,7 +418,7 @@ def _populate_private_class_attributes(cls): defaults.append( { kwarg: default - for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults) + for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults) if not kwarg.startswith("_") } ) @@ -641,7 +641,7 @@ def __init__(self, *args, **kwargs): def _set_default_frames_per_clip(self, inject_fake_data): argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__) - args_without_default = argspec.args[1:(-len(argspec.defaults) if argspec.defaults else None)] + args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)] frames_per_clip_last = args_without_default[-1] == "frames_per_clip" @functools.wraps(inject_fake_data) diff --git a/test/preprocess-bench.py b/test/preprocess-bench.py index 4ba3ca46dbc..84191fc40bd 100644 --- a/test/preprocess-bench.py +++ b/test/preprocess-bench.py @@ -1,47 +1,49 @@ import argparse import os from timeit import default_timer as timer -from torch.utils.model_zoo import tqdm + import torch import torch.utils.data import torchvision -import torchvision.transforms as transforms import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torch.utils.model_zoo import tqdm - -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -parser.add_argument('--data', metavar='PATH', required=True, - help='path to dataset') -parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N', - help='number of data loading threads (default: 2)') -parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N', - help='mini-batch size (1 = pure stochastic) Default: 256') -parser.add_argument('--accimage', action='store_true', - 
help='use accimage') +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument("--data", metavar="PATH", required=True, help="path to dataset") +parser.add_argument( + "--nThreads", "-j", default=2, type=int, metavar="N", help="number of data loading threads (default: 2)" +) +parser.add_argument( + "--batchSize", "-b", default=256, type=int, metavar="N", help="mini-batch size (1 = pure stochastic) Default: 256" +) +parser.add_argument("--accimage", action="store_true", help="use accimage") if __name__ == "__main__": args = parser.parse_args() if args.accimage: - torchvision.set_image_backend('accimage') - print('Using {}'.format(torchvision.get_image_backend())) + torchvision.set_image_backend("accimage") + print("Using {}".format(torchvision.get_image_backend())) # Data loading code - transform = transforms.Compose([ - transforms.RandomSizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') + transform = transforms.Compose( + [ + transforms.RandomSizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + traindir = os.path.join(args.data, "train") + valdir = os.path.join(args.data, "val") train = datasets.ImageFolder(traindir, transform) val = datasets.ImageFolder(valdir, transform) train_loader = torch.utils.data.DataLoader( - train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads) + train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads + ) train_iter = iter(train_loader) start_time = timer() @@ -51,9 +53,12 @@ pbar.update(1) batch = next(train_iter) end_time = timer() - print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch," - " {image:.2f} ms/image {rate:.0f} images/sec" - .format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0), - batch=(end_time - start_time) / float(batch_count) * 1.0e+3, - image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3, - rate=(batch_count * args.batchSize) / (end_time - start_time))) + print( + "Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch," + " {image:.2f} ms/image {rate:.0f} images/sec".format( + dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0), + batch=(end_time - start_time) / float(batch_count) * 1.0e3, + image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e3, + rate=(batch_count * args.batchSize) / (end_time - start_time), + ) + ) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 712dccf11a8..36fff446506 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,11 +1,11 @@ +import pytest + import torch from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -import pytest - -@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) +@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): - x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) - assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] + assert list(y.keys()) == ["0", "1", "2", "3", "pool"] 
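The hunks above all apply the same mechanical transformation: double quotes, trailing-comma wrapping at the 120-column limit, and imports regrouped into stdlib / third-party / torch–torchvision / local blocks, as configured by the [isort] section this patch adds to setup.cfg. Purely as an illustrative sketch (not part of the patch), the snippet below shows that grouping being produced from Python, assuming isort 5.x is installed; isort.code() is its documented API, and the sample import block is invented for the example.

# Minimal sketch: sort an import block with the options the new [isort] section configures.
import isort

messy = (
    "import torch\n"
    "import os\n"
    "from collections import deque\n"
    "import numpy as np\n"
)

# profile="black" and line_length=120 mirror setup.cfg; known_first_party makes
# torch/torchvision sort into their own group, as seen in the utils.py hunks above.
print(isort.code(messy, profile="black", line_length=120, known_first_party=["torch", "torchvision"]))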
diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index 6deb5d79739..0d66b045744 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -1,11 +1,12 @@ -import torch import os -import unittest -from torchvision import models, transforms import sys +import unittest from PIL import Image + +import torch import torchvision.transforms.functional as F +from torchvision import models try: from torchvision import _C_tests @@ -21,12 +22,13 @@ def process_model(model, tensor, func, name): py_output = model.forward(tensor) cpp_output = func("model.pt", tensor) - assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models' + assert torch.allclose(py_output, cpp_output), "Output mismatch of " + name + " models" def read_image1(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', - 'grace_hopper_517x606.jpg') + image_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" + ) image = Image.open(image_path) image = image.resize((224, 224)) x = F.to_tensor(image) @@ -34,8 +36,9 @@ def read_image1(): def read_image2(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', - 'grace_hopper_517x606.jpg') + image_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" + ) image = Image.open(image_path) image = image.resize((299, 299)) x = F.to_tensor(image) @@ -46,107 +49,110 @@ def read_image2(): @unittest.skipIf( sys.platform == "darwin" or True, "C++ models are broken on OS X at the moment, and there's a BC breakage on master; " - "see https://github.com/pytorch/vision/issues/1191") + "see https://github.com/pytorch/vision/issues/1191", +) class Tester(unittest.TestCase): pretrained = False image = read_image1() def test_alexnet(self): - process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet') + process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet") def test_vgg11(self): - process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11') + process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11") def test_vgg13(self): - process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13') + process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13") def test_vgg16(self): - process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16') + process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16") def test_vgg19(self): - process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19') + process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19") def test_vgg11_bn(self): - process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN') + process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN") def test_vgg13_bn(self): - process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN') + process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN") def test_vgg16_bn(self): - process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN') + 
process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN") def test_vgg19_bn(self): - process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN') + process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN") def test_resnet18(self): - process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18') + process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18") def test_resnet34(self): - process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34') + process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34") def test_resnet50(self): - process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50') + process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50") def test_resnet101(self): - process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101') + process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101") def test_resnet152(self): - process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152') + process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152") def test_resnext50_32x4d(self): - process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d') + process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d") def test_resnext101_32x8d(self): - process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d') + process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d") def test_wide_resnet50_2(self): - process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, 'WideResNet50_2') + process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, "WideResNet50_2") def test_wide_resnet101_2(self): - process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, 'WideResNet101_2') + process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2") def test_squeezenet1_0(self): - process_model(models.squeezenet1_0(self.pretrained), self.image, - _C_tests.forward_squeezenet1_0, 'Squeezenet1.0') + process_model( + models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0" + ) def test_squeezenet1_1(self): - process_model(models.squeezenet1_1(self.pretrained), self.image, - _C_tests.forward_squeezenet1_1, 'Squeezenet1.1') + process_model( + models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1" + ) def test_densenet121(self): - process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121') + process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121") def test_densenet169(self): - process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169') + process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169") def test_densenet201(self): - 
process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201') + process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201") def test_densenet161(self): - process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161') + process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161") def test_mobilenet_v2(self): - process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet') + process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet") def test_googlenet(self): - process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet') + process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet") def test_mnasnet0_5(self): - process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, 'MNASNet0_5') + process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") def test_mnasnet0_75(self): - process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, 'MNASNet0_75') + process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") def test_mnasnet1_0(self): - process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, 'MNASNet1_0') + process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") def test_mnasnet1_3(self): - process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, 'MNASNet1_3') + process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") def test_inception_v3(self): self.image = read_image2() - process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3') + process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_datasets.py b/test/test_datasets.py index 043398a0ca6..9c298f46894 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2,10 +2,10 @@ import contextlib import io import itertools +import json import os import pathlib import pickle -import json import random import shutil import string @@ -13,9 +13,10 @@ import xml.etree.ElementTree as ET import zipfile -import PIL import datasets_utils import numpy as np +import PIL + import torch import torch.nn.functional as F from torchvision import datasets @@ -23,8 +24,7 @@ class STL10TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.STL10 - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=("train", "test", "unlabeled", "train+unlabeled")) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled")) @staticmethod def _make_binary_file(num_elements, root, name): @@ -209,11 +209,11 @@ def inject_fake_data(self, tmpdir, config): class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.WIDERFace FEATURE_TYPES = (PIL.Image.Image, (dict, type(None))) # test split returns None as target - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test')) + 
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) def inject_fake_data(self, tmpdir, config): - widerface_dir = pathlib.Path(tmpdir) / 'widerface' - annotations_dir = widerface_dir / 'wider_face_split' + widerface_dir = pathlib.Path(tmpdir) / "widerface" + annotations_dir = widerface_dir / "wider_face_split" os.makedirs(annotations_dir) split_to_idx = split_to_num_examples = { @@ -223,21 +223,21 @@ def inject_fake_data(self, tmpdir, config): } # We need to create all folders regardless of the split in config - for split in ('train', 'val', 'test'): + for split in ("train", "val", "test"): split_idx = split_to_idx[split] num_examples = split_to_num_examples[split] datasets_utils.create_image_folder( root=tmpdir, - name=widerface_dir / f'WIDER_{split}' / 'images' / '0--Parade', + name=widerface_dir / f"WIDER_{split}" / "images" / "0--Parade", file_name_fn=lambda image_idx: f"0_Parade_marchingband_1_{split_idx + image_idx}.jpg", num_examples=num_examples, ) annotation_file_name = { - 'train': annotations_dir / 'wider_face_train_bbx_gt.txt', - 'val': annotations_dir / 'wider_face_val_bbx_gt.txt', - 'test': annotations_dir / 'wider_face_test_filelist.txt', + "train": annotations_dir / "wider_face_train_bbx_gt.txt", + "val": annotations_dir / "wider_face_val_bbx_gt.txt", + "test": annotations_dir / "wider_face_test_filelist.txt", }[split] annotation_content = { @@ -270,9 +270,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): "color", ) ADDITIONAL_CONFIGS = ( - *datasets_utils.combinations_grid( - mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES - ), + *datasets_utils.combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES), *datasets_utils.combinations_grid( mode=("coarse",), split=("train", "train_extra", "val"), @@ -327,6 +325,7 @@ def inject_fake_data(self, tmpdir, config): gt_dir = tmpdir / f"gt{mode}" for split in mode_to_splits[mode]: for city in cities: + def make_image(name, size=10): datasets_utils.create_image_folder( root=gt_dir / split, @@ -335,6 +334,7 @@ def make_image(name, size=10): size=size, num_examples=1, ) + make_image(f"{city}_000000_000000_gt{mode}_instanceIds.png") make_image(f"{city}_000000_000000_gt{mode}_labelIds.png") make_image(f"{city}_000000_000000_gt{mode}_color.png", size=(4, 10, 10)) @@ -344,7 +344,7 @@ def make_image(name, size=10): json.dump(polygon_target, outfile) # Create leftImg8bit folder - for split in ['test', 'train_extra', 'train', 'val']: + for split in ["test", "train_extra", "train", "val"]: for city in cities: datasets_utils.create_image_folder( root=tmpdir / "leftImg8bit" / split, @@ -353,13 +353,13 @@ def make_image(name, size=10): num_examples=1, ) - info = {'num_examples': len(cities)} - if config['target_type'] == 'polygon': - info['expected_polygon_target'] = polygon_target + info = {"num_examples": len(cities)} + if config["target_type"] == "polygon": + info["expected_polygon_target"] = polygon_target return info def test_combined_targets(self): - target_types = ['semantic', 'polygon', 'color'] + target_types = ["semantic", "polygon", "color"] with self.create_dataset(target_type=target_types) as (dataset, _): output = dataset[0] @@ -373,32 +373,32 @@ def test_combined_targets(self): self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color def test_feature_types_target_color(self): - with self.create_dataset(target_type='color') as (dataset, _): + with self.create_dataset(target_type="color") as (dataset, _): color_img, 
color_target = dataset[0] self.assertTrue(isinstance(color_img, PIL.Image.Image)) self.assertTrue(np.array(color_target).shape[2] == 4) def test_feature_types_target_polygon(self): - with self.create_dataset(target_type='polygon') as (dataset, info): + with self.create_dataset(target_type="polygon") as (dataset, info): polygon_img, polygon_target = dataset[0] self.assertTrue(isinstance(polygon_img, PIL.Image.Image)) - self.assertEqual(polygon_target, info['expected_polygon_target']) + self.assertEqual(polygon_target, info["expected_polygon_target"]) class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageNet - REQUIRED_PACKAGES = ('scipy',) - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val')) + REQUIRED_PACKAGES = ("scipy",) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val")) def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) - wnid = 'n01234567' - if config['split'] == 'train': + wnid = "n01234567" + if config["split"] == "train": num_examples = 3 datasets_utils.create_image_folder( root=tmpdir, - name=tmpdir / 'train' / wnid / wnid, + name=tmpdir / "train" / wnid / wnid, file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", num_examples=num_examples, ) @@ -406,13 +406,13 @@ def inject_fake_data(self, tmpdir, config): num_examples = 1 datasets_utils.create_image_folder( root=tmpdir, - name=tmpdir / 'val' / wnid, + name=tmpdir / "val" / wnid, file_name_fn=lambda image_ifx: "ILSVRC2012_val_0000000{image_idx}.JPEG", num_examples=num_examples, ) wnid_to_classes = {wnid: [1]} - torch.save((wnid_to_classes, None), tmpdir / 'meta.bin') + torch.save((wnid_to_classes, None), tmpdir / "meta.bin") return num_examples @@ -883,10 +883,7 @@ def inject_fake_data(self, tmpdir, config): return num_images @contextlib.contextmanager - def create_dataset( - self, - *args, **kwargs - ): + def create_dataset(self, *args, **kwargs): with super().create_dataset(*args, **kwargs) as output: yield output # Currently datasets.LSUN caches the keys in the current directory rather than in the root directory. Thus, @@ -946,14 +943,12 @@ def test_not_found_or_corrupted(self): class KineticsTestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.Kinetics - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=("train", "val"), num_classes=("400", "600", "700") - ) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"), num_classes=("400", "600", "700")) def inject_fake_data(self, tmpdir, config): classes = ("Abseiling", "Zumba") num_videos_per_class = 2 - tmpdir = pathlib.Path(tmpdir) / config['split'] + tmpdir = pathlib.Path(tmpdir) / config["split"] digits = string.ascii_letters + string.digits + "-_" for cls in classes: datasets_utils.create_video_folder( @@ -1576,7 +1571,7 @@ def test_is_valid_file(self, config): # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the # DEFAULT_CONFIG. 
with self.create_dataset( - config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions + config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions ) as (dataset, info): self.assertEqual(len(dataset), info["num_examples"]) @@ -1660,7 +1655,7 @@ def inject_fake_data(self, tmpdir, config): file = f"{split}_32x32.mat" images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8) targets = np.zeros((num_examples,), dtype=np.uint8) - sio.savemat(os.path.join(tmpdir, file), {'X': images, 'y': targets}) + sio.savemat(os.path.join(tmpdir, file), {"X": images, "y": targets}) return num_examples @@ -1695,8 +1690,7 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase): # (file, idx) _FILE_LIST_CONTENT = ( ("Places365_val_00000001.png", 0), - *((f"{category}/Places365_train_00000001.png", idx) - for category, idx in _CATEGORIES_CONTENT), + *((f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT), ) @staticmethod @@ -1736,8 +1730,8 @@ def _make_images_archive(root, split, small): return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)] def inject_fake_data(self, tmpdir, config): - self._make_devkit_archive(tmpdir, config['split']) - return len(self._make_images_archive(tmpdir, config['split'], config['small'])) + self._make_devkit_archive(tmpdir, config["split"]) + return len(self._make_images_archive(tmpdir, config["split"], config["small"])) def test_classes(self): classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT)) @@ -1751,7 +1745,7 @@ def test_class_to_idx(self): def test_images_download_preexisting(self): with self.assertRaises(RuntimeError): - with self.create_dataset({'download': True}): + with self.create_dataset({"download": True}): pass diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 8c2d575e01d..53077477892 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -1,25 +1,25 @@ import contextlib import itertools +import tempfile import time import unittest.mock +import warnings from datetime import datetime from distutils import dir_util from os import path from urllib.error import HTTPError, URLError from urllib.parse import urlparse -from urllib.request import urlopen, Request -import tempfile -import warnings +from urllib.request import Request, urlopen import pytest from torchvision import datasets from torchvision.datasets.utils import ( - download_url, + USER_AGENT, + _get_redirect_url, check_integrity, download_file_from_google_drive, - _get_redirect_url, - USER_AGENT, + download_url, ) from common_utils import get_tmp_dir diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 7754c1a98e8..bda2a6570d0 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,19 +1,14 @@ import contextlib -import sys import os -import torch + import pytest +import torch from torchvision import io -from torchvision.datasets.samplers import ( - DistributedSampler, - RandomClipSampler, - UniformClipSampler, -) -from torchvision.datasets.video_utils import VideoClips, unfold -from torchvision import get_video_backend +from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler +from torchvision.datasets.video_utils import VideoClips -from common_utils import get_tmp_dir, assert_equal +from common_utils import assert_equal, get_tmp_dir @contextlib.contextmanager @@ -45,7 +40,7 @@ def 
test_random_clip_sampler(self): sampler = RandomClipSampler(video_clips, 3) assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1, 2])) assert_equal(count, torch.tensor([3, 3, 3])) @@ -62,7 +57,7 @@ def test_random_clip_sampler_unequal(self): indices.remove(0) indices.remove(1) indices = torch.tensor(indices) - 2 - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1])) assert_equal(count, torch.tensor([3, 3])) @@ -73,7 +68,7 @@ def test_uniform_clip_sampler(self): sampler = UniformClipSampler(video_clips, 3) assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1, 2])) assert_equal(count, torch.tensor([3, 3, 3])) @@ -113,5 +108,5 @@ def test_distributed_sampler_and_uniform_clip_sampler(self): assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 34ca3da6847..867ca97f28c 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,38 +1,34 @@ -import bz2 +import gzip import os -import torchvision.datasets.utils as utils -import pytest -import zipfile import tarfile -import gzip -import warnings -from torch._utils_internal import get_file_path_2 -from urllib.error import URLError -import itertools -import lzma +import zipfile -from common_utils import get_tmp_dir +import pytest + +import torchvision.datasets.utils as utils +from torch._utils_internal import get_file_path_2 from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS +from common_utils import get_tmp_dir TEST_FILE = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" +) class TestDatasetsUtils: - def test_check_md5(self): fpath = TEST_FILE - correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' - false_md5 = '' + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" assert utils.check_md5(fpath, correct_md5) assert not utils.check_md5(fpath, false_md5) def test_check_integrity(self): existing_fpath = TEST_FILE - nonexisting_fpath = '' - correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' - false_md5 = '' + nonexisting_fpath = "" + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" assert utils.check_integrity(existing_fpath, correct_md5) assert not utils.check_integrity(existing_fpath, false_md5) assert utils.check_integrity(existing_fpath) @@ -50,31 +46,35 @@ def test_get_google_drive_file_id_invalid_url(self): assert utils._get_google_drive_file_id(url) is None - @pytest.mark.parametrize('file, expected', [ - ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), - ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), - ("foo.tar", (".tar", ".tar", None)), - ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.tbz", (".tbz", ".tar", ".bz2")), - ("foo.tbz2", (".tbz2", ".tar", ".bz2")), - 
("foo.tgz", (".tgz", ".tar", ".gz")), - ("foo.bz2", (".bz2", None, ".bz2")), - ("foo.gz", (".gz", None, ".gz")), - ("foo.zip", (".zip", ".zip", None)), - ("foo.xz", (".xz", None, ".xz")), - ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.bar.gz", (".gz", None, ".gz")), - ("foo.bar.zip", (".zip", ".zip", None))]) + @pytest.mark.parametrize( + "file, expected", + [ + ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), + ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), + ("foo.tar", (".tar", ".tar", None)), + ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.tbz", (".tbz", ".tar", ".bz2")), + ("foo.tbz2", (".tbz2", ".tar", ".bz2")), + ("foo.tgz", (".tgz", ".tar", ".gz")), + ("foo.bz2", (".bz2", None, ".bz2")), + ("foo.gz", (".gz", None, ".gz")), + ("foo.zip", (".zip", ".zip", None)), + ("foo.xz", (".xz", None, ".xz")), + ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.bar.gz", (".gz", None, ".gz")), + ("foo.bar.zip", (".zip", ".zip", None)), + ], + ) def test_detect_file_type(self, file, expected): assert utils._detect_file_type(file) == expected - @pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"]) + @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"]) def test_detect_file_type_incompatible(self, file): # tests detect file type for no extension, unknown compression and unknown partial extension with pytest.raises(RuntimeError): utils._detect_file_type(file) - @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"]) + @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"]) def test_decompress(self, extension): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") @@ -117,8 +117,8 @@ def create_compressed(root, content="this is the content"): assert not os.path.exists(compressed) - @pytest.mark.parametrize('extension', [".gz", ".xz"]) - @pytest.mark.parametrize('remove_finished', [True, False]) + @pytest.mark.parametrize("extension", [".gz", ".xz"]) + @pytest.mark.parametrize("remove_finished", [True, False]) def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): filename = "foo" file = f"{filename}{extension}" @@ -148,8 +148,9 @@ def create_archive(root, content="this is the content"): with open(file, "r") as fh: assert fh.read() == content - @pytest.mark.parametrize('extension, mode', [ - ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')]) + @pytest.mark.parametrize( + "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")] + ) def test_extract_tar(self, extension, mode): def create_archive(root, extension, mode, content="this is the content"): src = os.path.join(root, "src.txt") @@ -180,5 +181,5 @@ def test_verify_str_arg(self): pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index 00db0aad127..ac1a1fffa1c 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,12 +1,13 @@ import contextlib import os -import torch + import pytest +import torch from torchvision import io from torchvision.datasets.video_utils import VideoClips, unfold -from common_utils import get_tmp_dir, assert_equal +from common_utils import assert_equal, get_tmp_dir @contextlib.contextmanager @@ -31,30 +32,29 @@ def get_list_of_videos(num_videos=5, sizes=None, fps=None): class TestVideo: - def test_unfold(self): a = 
torch.arange(7) r = unfold(a, 3, 3, 1) - expected = torch.tensor([ - [0, 1, 2], - [3, 4, 5], - ]) + expected = torch.tensor( + [ + [0, 1, 2], + [3, 4, 5], + ] + ) assert_equal(r, expected) r = unfold(a, 3, 2, 1) - expected = torch.tensor([ - [0, 1, 2], - [2, 3, 4], - [4, 5, 6] - ]) + expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]]) assert_equal(r, expected) r = unfold(a, 3, 2, 2) - expected = torch.tensor([ - [0, 2, 4], - [2, 4, 6], - ]) + expected = torch.tensor( + [ + [0, 2, 4], + [2, 4, 6], + ] + ) assert_equal(r, expected) @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") @@ -100,8 +100,7 @@ def test_compute_clips_for_video(self): orig_fps = 30 duration = float(len(video_pts)) / orig_fps new_fps = 13 - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) assert len(clips) == 1 assert_equal(clips, idxs) @@ -112,8 +111,7 @@ def test_compute_clips_for_video(self): orig_fps = 30 duration = float(len(video_pts)) / orig_fps new_fps = 12 - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) assert len(clips) == 3 assert_equal(clips, idxs) @@ -124,11 +122,10 @@ def test_compute_clips_for_video(self): orig_fps = 30 new_fps = 13 with pytest.warns(UserWarning): - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) assert len(clips) == 0 assert len(idxs) == 0 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_video_utils_opt.py b/test/test_datasets_video_utils_opt.py index 8075c701ed9..dfa0eccb155 100644 --- a/test/test_datasets_video_utils_opt.py +++ b/test/test_datasets_video_utils_opt.py @@ -1,11 +1,13 @@ import unittest -from torchvision import set_video_backend + import test_datasets_video_utils +from torchvision import set_video_backend # noqa: F401 + # Disabling the video backend switching temporarily # set_video_backend('video_reader') -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils) unittest.TextTestRunner(verbosity=1).run(suite) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 5ce82304569..baeb80e93a9 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,32 +1,30 @@ -import itertools -import os import colorsys +import itertools import math +import os +from typing import Sequence import numpy as np import pytest import torch -import torchvision.transforms.functional_tensor as F_t -import torchvision.transforms.functional_pil as F_pil -import torchvision.transforms.functional as F import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.functional_pil as F_pil +import torchvision.transforms.functional_tensor as F_t from torchvision.transforms import InterpolationMode from common_utils import ( - cpu_and_gpu, - needs_cuda, + 
_assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, _test_fn_on_batch, assert_equal, + cpu_and_gpu, + needs_cuda, ) -from typing import Dict, List, Sequence, Tuple - - NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC @@ -38,10 +36,10 @@ def test_scale_channel(): # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed, # only use bincount and remove that test. size = (1_000,) - img_chan = torch.randint(0, 256, size=size).to('cpu') + img_chan = torch.randint(0, 256, size=size).to("cpu") scaled_cpu = F_t._scale_channel(img_chan) - scaled_cuda = F_t._scale_channel(img_chan.to('cuda')) - assert_equal(scaled_cpu, scaled_cuda.to('cpu')) + scaled_cuda = F_t._scale_channel(img_chan.to("cuda")) + assert_equal(scaled_cpu, scaled_cuda.to("cpu")) class TestRotate: @@ -50,18 +48,33 @@ class TestRotate: scripted_rotate = torch.jit.script(F.rotate) IMG_W = 26 - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, IMG_W), (32, IMG_W)]) - @pytest.mark.parametrize('center', [ - None, - (int(IMG_W * 0.3), int(IMG_W * 0.4)), - [int(IMG_W * 0.5), int(IMG_W * 0.6)], - ]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle', range(-180, 180, 17)) - @pytest.mark.parametrize('expand', [True, False]) - @pytest.mark.parametrize('fill', [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]) - @pytest.mark.parametrize('fn', [F.rotate, scripted_rotate]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)]) + @pytest.mark.parametrize( + "center", + [ + None, + (int(IMG_W * 0.3), int(IMG_W * 0.4)), + [int(IMG_W * 0.5), int(IMG_W * 0.6)], + ], + ) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("angle", range(-180, 180, 17)) + @pytest.mark.parametrize("expand", [True, False]) + @pytest.mark.parametrize( + "fill", + [ + None, + [0, 0, 0], + (1, 2, 3), + [255, 255, 255], + [ + 1, + ], + (2.0,), + ], + ) + @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate]) def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn): tensor, pil_img = _create_data(height, width, device=device) @@ -82,8 +95,8 @@ def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn out_tensor = out_tensor.to(torch.uint8) assert out_tensor.shape == out_pil_tensor.shape, ( - f"{(height, width, NEAREST, dt, angle, expand, center)}: " - f"{out_tensor.shape} vs {out_pil_tensor.shape}") + f"{(height, width, NEAREST, dt, angle, expand, center)}: " f"{out_tensor.shape} vs {out_pil_tensor.shape}" + ) num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] @@ -91,10 +104,11 @@ def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn assert ratio_diff_pixels < 0.03, ( f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: " f"{ratio_diff_pixels}\n{out_tensor[0, :7, :7]} vs \n" - f"{out_pil_tensor[0, :7, :7]}") + f"{out_pil_tensor[0, :7, :7]}" + ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', ALL_DTYPES) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_rotate_batch(self, device, dt): if dt == torch.float16 and device == "cpu": # skip 
float16 on CPU case @@ -105,9 +119,7 @@ def test_rotate_batch(self, device, dt): batch_tensors = batch_tensors.to(dtype=dt) center = (20, 22) - _test_fn_on_batch( - batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center - ) + _test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center) def test_rotate_deprecation_resample(self): tensor, _ = _create_data(26, 26) @@ -131,9 +143,9 @@ class TestAffine: ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16] scripted_affine = torch.jit.script(F.affine) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_identity_map(self, device, height, width, dt): # Tests on square and rectangular images tensor, pil_img = _create_data(height, width, device=device) @@ -154,19 +166,22 @@ def test_identity_map(self, device, height, width, dt): ) assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle, config', [ - (90, {'k': 1, 'dims': (-1, -2)}), - (45, None), - (30, None), - (-30, None), - (-45, None), - (-90, {'k': -1, 'dims': (-1, -2)}), - (180, {'k': 2, 'dims': (-1, -2)}), - ]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize( + "angle, config", + [ + (90, {"k": 1, "dims": (-1, -2)}), + (45, None), + (30, None), + (-30, None), + (-45, None), + (-90, {"k": -1, "dims": (-1, -2)}), + (180, {"k": 2, "dims": (-1, -2)}), + ], + ) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_square_rotations(self, device, height, width, dt, angle, config, fn): # 2) Test rotation tensor, pil_img = _create_data(height, width, device=device) @@ -183,9 +198,7 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device) - out_tensor = fn( - tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST - ) + out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) if config is not None: assert_equal(torch.rot90(tensor, **config), out_tensor) @@ -199,11 +212,11 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle', [90, 45, 15, -30, -60, -120]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120]) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_rect_rotations(self, device, 
height, width, dt, angle, fn): # Tests on rectangular images tensor, pil_img = _create_data(height, width, device=device) @@ -220,9 +233,7 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn): ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - out_tensor = fn( - tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST - ).cpu() + out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -234,11 +245,11 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn): angle, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('t', [[10, 12], (-12, -13)]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("t", [[10, 12], (-12, -13)]) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_translations(self, device, height, width, dt, t, fn): # 3) Test translation tensor, pil_img = _create_data(height, width, device=device) @@ -259,22 +270,41 @@ def test_translations(self, device, height, width, dt, t, fn): _assert_equal_tensor_to_pil(out_tensor, out_pil_img) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('a, t, s, sh, f', [ - (45.5, [5, 6], 1.0, [0.0, 0.0], None), - (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), - (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)), - (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), - (85, (10, -10), 0.7, [0.0, 0.0], [1, ]), - (0, [0, 0], 1.0, [35.0, ], (2.0, )), - (-25, [0, 0], 1.2, [0.0, 15.0], None), - (-45, [-10, 0], 0.7, [2.0, 5.0], None), - (-45, [-10, -10], 1.2, [4.0, 5.0], None), - (-90, [0, 0], 1.0, [0.0, 0.0], None), - ]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize( + "a, t, s, sh, f", + [ + (45.5, [5, 6], 1.0, [0.0, 0.0], None), + (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), + (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)), + (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), + ( + 85, + (10, -10), + 0.7, + [0.0, 0.0], + [ + 1, + ], + ), + ( + 0, + [0, 0], + 1.0, + [ + 35.0, + ], + (2.0,), + ), + (-25, [0, 0], 1.2, [0.0, 15.0], None), + (-45, [-10, 0], 0.7, [2.0, 5.0], None), + (-45, [-10, -10], 1.2, [4.0, 5.0], None), + (-90, [0, 0], 1.0, [0.0, 0.0], None), + ], + ) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn): # 4) Test rotation + translation + scale + shear tensor, pil_img = _create_data(height, width, device=device) @@ -303,8 +333,8 @@ def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn): (NEAREST, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', ALL_DTYPES) + 
@pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_batches(self, device, dt): if dt == torch.float16 and device == "cpu": # skip float16 on CPU case @@ -314,11 +344,9 @@ def test_batches(self, device, dt): if dt is not None: batch_tensors = batch_tensors.to(dtype=dt) - _test_fn_on_batch( - batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] - ) + _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) - @pytest.mark.parametrize('device', cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu()) def test_warnings(self, device): tensor, pil_img = _create_data(26, 26, device=device) @@ -360,18 +388,27 @@ def _get_data_dims_and_points_for_perspective(): n = 10 for dim in data_dims: - points += [ - (dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) - for i in range(n) - ] + points += [(dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) for i in range(n)] return dims_and_points -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, ))) -@pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "fill", + ( + None, + [0, 0, 0], + [1, 2, 3], + [255, 255, 255], + [ + 1, + ], + (2.0,), + ), +) +@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)]) def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): if dt == torch.float16 and device == "cpu": @@ -386,8 +423,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): interpolation = NEAREST fill_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill - out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation, - fill=fill_pil) + out_pil_img = F.perspective( + pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill_pil + ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill).cpu() @@ -400,9 +438,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): assert ratio_diff_pixels < 0.05 -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) def test_perspective_batch(device, dims_and_points, dt): if dt == torch.float16 and device == "cpu": @@ -419,8 +457,12 @@ def test_perspective_batch(device, dims_and_points, dt): # the border may be entirely different due to small rounding errors. 
scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8 _test_fn_on_batch( - batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol, - startpoints=spoints, endpoints=epoints, interpolation=NEAREST + batch_tensors, + F.perspective, + scripted_fn_atol=scripted_fn_atol, + startpoints=spoints, + endpoints=epoints, + interpolation=NEAREST, ) @@ -435,11 +477,23 @@ def test_perspective_interpolation_warning(): assert_equal(res1, res2) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]) -@pytest.mark.parametrize('max_size', [None, 34, 40, 1000]) -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "size", + [ + 32, + 26, + [ + 32, + ], + [32, 32], + (32, 32), + [26, 35], + ], +) +@pytest.mark.parametrize("max_size", [None, 34, 40, 1000]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) def test_resize(device, dt, size, max_size, interpolation): if dt == torch.float16 and device == "cpu": @@ -464,7 +518,9 @@ def test_resize(device, dt, size, max_size, interpolation): assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] - if interpolation not in [NEAREST, ]: + if interpolation not in [ + NEAREST, + ]: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] @@ -477,21 +533,19 @@ def test_resize(device, dt, size, max_size, interpolation): _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0) if isinstance(size, int): - script_size = [size, ] + script_size = [ + size, + ] else: script_size = size - resize_result = script_fn( - tensor, size=script_size, interpolation=interpolation, max_size=max_size - ) + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size) assert_equal(resized_tensor, resize_result) - _test_fn_on_batch( - batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size - ) + _test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_resize_asserts(device): tensor, pil_img = _create_data(26, 36, device=device) @@ -511,10 +565,10 @@ def test_resize_asserts(device): F.resize(img, size=32, max_size=32) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]]) -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) def test_resize_antialias(device, dt, size, interpolation): if dt == torch.float16 and device == "cpu": @@ -539,9 +593,7 @@ def test_resize_antialias(device, dt, size, interpolation): if resized_tensor_f.dtype == torch.uint8: resized_tensor_f = resized_tensor_f.to(torch.float) - 
_assert_approx_equal_tensor_to_pil( - resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}" - ) + _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}") accepted_tol = 1.0 + 1e-5 if interpolation == BICUBIC: @@ -552,12 +604,13 @@ def test_resize_antialias(device, dt, size, interpolation): accepted_tol = 15.0 _assert_approx_equal_tensor_to_pil( - resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", - msg=f"{size}, {interpolation}, {dt}" + resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", msg=f"{size}, {interpolation}, {dt}" ) if isinstance(size, int): - script_size = [size, ] + script_size = [ + size, + ] else: script_size = size @@ -566,7 +619,7 @@ def test_resize_antialias(device, dt, size, interpolation): @needs_cuda -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) def test_assert_resize_antialias(interpolation): # Checks implementation on very large scales @@ -613,9 +666,9 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) def test_adjust_brightness(device, dtype, config): check_functional_vs_PIL_vs_scripted( F.adjust_brightness, @@ -627,23 +680,16 @@ def test_adjust_brightness(device, dtype, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) def test_invert(device, dtype): check_functional_vs_PIL_vs_scripted( - F.invert, - F_pil.invert, - F_t.invert, - {}, - device, - dtype, - tol=1.0, - agg_method="max" + F.invert, F_pil.invert, F_t.invert, {}, device, dtype, tol=1.0, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)]) def test_posterize(device, config): check_functional_vs_PIL_vs_scripted( F.posterize, @@ -657,8 +703,8 @@ def test_posterize(device, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]]) def test_solarize1(device, config): check_functional_vs_PIL_vs_scripted( F.solarize, @@ -672,9 +718,9 @@ def test_solarize1(device, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", 
(torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) def test_solarize2(device, dtype, config): check_functional_vs_PIL_vs_scripted( F.solarize, @@ -688,9 +734,9 @@ def test_solarize2(device, dtype, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) def test_adjust_sharpness(device, dtype, config): check_functional_vs_PIL_vs_scripted( F.adjust_sharpness, @@ -702,22 +748,15 @@ def test_adjust_sharpness(device, dtype, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) def test_autocontrast(device, dtype): check_functional_vs_PIL_vs_scripted( - F.autocontrast, - F_pil.autocontrast, - F_t.autocontrast, - {}, - device, - dtype, - tol=1.0, - agg_method="max" + F.autocontrast, F_pil.autocontrast, F_t.autocontrast, {}, device, dtype, tol=1.0, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_equalize(device): torch.set_deterministic(False) check_functional_vs_PIL_vs_scripted( @@ -732,53 +771,36 @@ def test_equalize(device): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) def test_adjust_contrast(device, dtype, config): check_functional_vs_PIL_vs_scripted( - F.adjust_contrast, - F_pil.adjust_contrast, - F_t.adjust_contrast, - config, - device, - dtype + F.adjust_contrast, F_pil.adjust_contrast, F_t.adjust_contrast, config, device, dtype ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) def test_adjust_saturation(device, dtype, config): check_functional_vs_PIL_vs_scripted( - F.adjust_saturation, - F_pil.adjust_saturation, - F_t.adjust_saturation, - config, - device, - dtype + F.adjust_saturation, F_pil.adjust_saturation, F_t.adjust_saturation, config, device, dtype ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", 
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) def test_adjust_hue(device, dtype, config): check_functional_vs_PIL_vs_scripted( - F.adjust_hue, - F_pil.adjust_hue, - F_t.adjust_hue, - config, - device, - dtype, - tol=16.1, - agg_method="max" + F.adjust_hue, F_pil.adjust_hue, F_t.adjust_hue, config, device, dtype, tol=16.1, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) def test_adjust_gamma(device, dtype, config): check_functional_vs_PIL_vs_scripted( F.adjust_gamma, @@ -790,17 +812,31 @@ def test_adjust_gamma(device, dtype, config): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('pad', [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]) -@pytest.mark.parametrize('config', [ - {"padding_mode": "constant", "fill": 0}, - {"padding_mode": "constant", "fill": 10}, - {"padding_mode": "constant", "fill": 20}, - {"padding_mode": "edge"}, - {"padding_mode": "reflect"}, - {"padding_mode": "symmetric"}, -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "pad", + [ + 2, + [ + 3, + ], + [0, 3], + (3, 3), + [4, 2, 4, 3], + ], +) +@pytest.mark.parametrize( + "config", + [ + {"padding_mode": "constant", "fill": 0}, + {"padding_mode": "constant", "fill": 10}, + {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "edge"}, + {"padding_mode": "reflect"}, + {"padding_mode": "symmetric"}, + ], +) def test_pad(device, dt, pad, config): script_fn = torch.jit.script(F.pad) tensor, pil_img = _create_data(7, 8, device=device) @@ -826,7 +862,9 @@ def test_pad(device, dt, pad, config): _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, config)) if isinstance(pad, int): - script_pad = [pad, ] + script_pad = [ + pad, + ] else: script_pad = pad pad_tensor_script = script_fn(tensor, script_pad, **config) @@ -835,8 +873,8 @@ def test_pad(device, dt, pad, config): _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('mode', [NEAREST, BILINEAR, BICUBIC]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC]) def test_resized_crop(device, mode): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity @@ -861,20 +899,49 @@ def test_resized_crop(device, mode): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('func, args', [ - (F_t._get_image_size, ()), (F_t.vflip, ()), - (F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)), - (F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )), - (F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )), - (F_t.center_crop, ([10, 11], )), (F_t.five_crop, ([10, 11], )), - (F_t.ten_crop, ([10, 11], )), (F_t.pad, ([2, ], 2, "constant")), - (F_t.resize, ([10, 11], )), (F_t.perspective, ([0.2, ])), - (F_t.gaussian_blur, ((2, 2), 
(0.7, 0.5))), - (F_t.invert, ()), (F_t.posterize, (0, )), - (F_t.solarize, (0.3, )), (F_t.adjust_sharpness, (0.3, )), - (F_t.autocontrast, ()), (F_t.equalize, ()) -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "func, args", + [ + (F_t._get_image_size, ()), + (F_t.vflip, ()), + (F_t.hflip, ()), + (F_t.crop, (1, 2, 4, 5)), + (F_t.adjust_brightness, (0.0,)), + (F_t.adjust_contrast, (1.0,)), + (F_t.adjust_hue, (-0.5,)), + (F_t.adjust_saturation, (2.0,)), + (F_t.center_crop, ([10, 11],)), + (F_t.five_crop, ([10, 11],)), + (F_t.ten_crop, ([10, 11],)), + ( + F_t.pad, + ( + [ + 2, + ], + 2, + "constant", + ), + ), + (F_t.resize, ([10, 11],)), + ( + F_t.perspective, + ( + [ + 0.2, + ] + ), + ), + (F_t.gaussian_blur, ((2, 2), (0.7, 0.5))), + (F_t.invert, ()), + (F_t.posterize, (0,)), + (F_t.solarize, (0.3,)), + (F_t.adjust_sharpness, (0.3,)), + (F_t.autocontrast, ()), + (F_t.equalize, ()), + ], +) def test_assert_image_tensor(device, func, args): shape = (100,) tensor = torch.rand(*shape, dtype=torch.float, device=device) @@ -882,7 +949,7 @@ def test_assert_image_tensor(device, func, args): func(tensor, *args) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_vflip(device): script_vflip = torch.jit.script(F.vflip) @@ -899,7 +966,7 @@ def test_vflip(device): _test_fn_on_batch(batch_tensors, F.vflip) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_hflip(device): script_hflip = torch.jit.script(F.hflip) @@ -916,13 +983,16 @@ def test_hflip(device): _test_fn_on_batch(batch_tensors, F.hflip) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('top, left, height, width', [ - (1, 2, 4, 5), # crop inside top-left corner - (2, 12, 3, 4), # crop inside top-right corner - (8, 3, 5, 6), # crop inside bottom-left corner - (8, 11, 4, 3), # crop inside bottom-right corner -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "top, left, height, width", + [ + (1, 2, 4, 5), # crop inside top-left corner + (2, 12, 3, 4), # crop inside top-right corner + (8, 3, 5, 6), # crop inside bottom-left corner + (8, 11, 4, 3), # crop inside bottom-right corner + ], +) def test_crop(device, top, left, height, width): script_crop = torch.jit.script(F.crop) @@ -940,12 +1010,12 @@ def test_crop(device, top, left, height, width): _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('image_size', ('small', 'large')) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('ksize', [(3, 3), [3, 5], (23, 23)]) -@pytest.mark.parametrize('sigma', [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) -@pytest.mark.parametrize('fn', [F.gaussian_blur, torch.jit.script(F.gaussian_blur)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("image_size", ("small", "large")) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) +@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) +@pytest.mark.parametrize("fn", [F.gaussian_blur, torch.jit.script(F.gaussian_blur)]) def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): # true_cv2_results = { @@ -962,17 +1032,15 @@ def test_gaussian_blur(device, image_size, dt, 
ksize, sigma, fn): # # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7) # "23_23_1.7": ... # } - p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt') + p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") true_cv2_results = torch.load(p) - if image_size == 'small': - tensor = torch.from_numpy( - np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) - ).permute(2, 0, 1).to(device) + if image_size == "small": + tensor = ( + torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) + ) else: - tensor = torch.from_numpy( - np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28)) - ).to(device) + tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device) if dt == torch.float16 and device == "cpu": # skip float16 on CPU case @@ -984,22 +1052,19 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize _sigma = sigma[0] if sigma is not None else None shape = tensor.shape - gt_key = "{}_{}_{}__{}_{}_{}".format( - shape[-2], shape[-1], shape[-3], - _ksize[0], _ksize[1], _sigma - ) + gt_key = "{}_{}_{}__{}_{}_{}".format(shape[-2], shape[-1], shape[-3], _ksize[0], _ksize[1], _sigma) if gt_key not in true_cv2_results: return - true_out = torch.tensor( - true_cv2_results[gt_key] - ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) + true_out = ( + torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) + ) out = fn(tensor, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma)) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_hsv2rgb(device): scripted_fn = torch.jit.script(F_t._hsv2rgb) shape = (3, 100, 150) @@ -1008,7 +1073,11 @@ def test_hsv2rgb(device): rgb_img = F_t._hsv2rgb(hsv_img) ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1) - h, s, v, = hsv_img.unbind(0) + ( + h, + s, + v, + ) = hsv_img.unbind(0) h = h.flatten().cpu().numpy() s = s.flatten().cpu().numpy() v = v.flatten().cpu().numpy() @@ -1026,7 +1095,7 @@ def test_hsv2rgb(device): _test_fn_on_batch(batch_tensors, F_t._hsv2rgb) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_rgb2hsv(device): scripted_fn = torch.jit.script(F_t._rgb2hsv) shape = (3, 150, 100) @@ -1035,7 +1104,11 @@ def test_rgb2hsv(device): hsv_img = F_t._rgb2hsv(rgb_img) ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1) - r, g, b, = rgb_img.unbind(dim=-3) + ( + r, + g, + b, + ) = rgb_img.unbind(dim=-3) r = r.flatten().cpu().numpy() g = g.flatten().cpu().numpy() b = b.flatten().cpu().numpy() @@ -1061,8 +1134,8 @@ def test_rgb2hsv(device): _test_fn_on_batch(batch_tensors, F_t._rgb2hsv) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('num_output_channels', (3, 1)) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("num_output_channels", (3, 1)) def test_rgb_to_grayscale(device, num_output_channels): script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) @@ -1080,7 +1153,7 @@ def test_rgb_to_grayscale(device, num_output_channels): _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) -@pytest.mark.parametrize('device', cpu_and_gpu()) 
+@pytest.mark.parametrize("device", cpu_and_gpu())
 def test_center_crop(device):
     script_center_crop = torch.jit.script(F.center_crop)
 
@@ -1098,7 +1171,7 @@ def test_center_crop(device):
     _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
 
 
-@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize("device", cpu_and_gpu())
 def test_five_crop(device):
     script_five_crop = torch.jit.script(F.five_crop)
 
@@ -1132,7 +1205,7 @@ def test_five_crop(device):
     assert_equal(transformed_batch, s_transformed_batch)
 
 
-@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize("device", cpu_and_gpu())
 def test_ten_crop(device):
     script_ten_crop = torch.jit.script(F.ten_crop)
 
@@ -1166,5 +1239,5 @@ def test_ten_crop(device):
     assert_equal(transformed_batch, s_transformed_batch)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/test/test_hub.py b/test/test_hub.py
index 9c9e417933e..c035d12d07f 100644
--- a/test/test_hub.py
+++ b/test/test_hub.py
@@ -1,10 +1,12 @@
-import torch.hub as hub
-import tempfile
-import shutil
 import os
+import shutil
 import sys
+import tempfile
+
 import pytest
+
+import torch.hub as hub
+
 
 def sum_of_model_parameters(model):
     s = 0
@@ -16,8 +18,7 @@ def sum_of_model_parameters(model):
 SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
 
 
-@pytest.mark.skipif('torchvision' in sys.modules,
-                    reason='TestHub must start without torchvision imported')
+@pytest.mark.skipif("torchvision" in sys.modules, reason="TestHub must start without torchvision imported")
 class TestHub:
     # Only run this check ONCE before all tests start.
     # - If torchvision is imported before all tests start, e.g. we might find _C.so
@@ -26,28 +27,20 @@ class TestHub:
     # Python cache as we run all hub tests in the same python process.
     def test_load_from_github(self):
-        hub_model = hub.load(
-            'pytorch/vision',
-            'resnet18',
-            pretrained=True,
-            progress=False)
+        hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
 
     def test_set_dir(self):
         temp_dir = tempfile.gettempdir()
         hub.set_dir(temp_dir)
-        hub_model = hub.load(
-            'pytorch/vision',
-            'resnet18',
-            pretrained=True,
-            progress=False)
+        hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
-        assert os.path.exists(temp_dir + '/pytorch_vision_master')
-        shutil.rmtree(temp_dir + '/pytorch_vision_master')
+        assert os.path.exists(temp_dir + "/pytorch_vision_master")
+        shutil.rmtree(temp_dir + "/pytorch_vision_master")
 
     def test_list_entrypoints(self):
-        entry_lists = hub.list('pytorch/vision', force_reload=True)
-        assert 'resnet18' in entry_lists
+        entry_lists = hub.list("pytorch/vision", force_reload=True)
+        assert "resnet18" in entry_lists
 
 
 if __name__ == "__main__":
diff --git a/test/test_image.py b/test/test_image.py
index e7e5b8b197d..4fa5d93a11a 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -4,24 +4,35 @@
 import sys
 from pathlib import Path
 
-import pytest
 import numpy as np
-import torch
+import pytest
 from PIL import Image, __version__ as PILLOW_VERSION
 
-import torchvision.transforms.functional as F
-from common_utils import get_tmp_dir, needs_cuda, assert_equal
+import torch
+import torchvision.transforms.functional as F
 from torchvision.io.image import (
-    decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
-    encode_png, write_png, write_file, ImageReadMode, read_image)
+    ImageReadMode,
+    decode_image,
+    decode_jpeg,
+    decode_png,
+    encode_jpeg,
+    encode_png,
+    read_file,
+    read_image,
+    write_file,
+    write_jpeg,
+    write_png,
+)
+
+from common_utils import assert_equal, get_tmp_dir, needs_cuda
 
 IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
 FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
 IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
-DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
 ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
-IS_WINDOWS = sys.platform in ('win32', 'cygwin')
-PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
+IS_WINDOWS = sys.platform in ("win32", "cygwin")
+PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
 
 
 def _get_safe_image_name(name):
@@ -34,9 +45,9 @@ def _get_safe_image_name(name):
 def get_images(directory, img_ext):
     assert os.path.isdir(directory)
-    image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
+    image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
     for path in image_paths:
-        if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
+        if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
             yield path
 
 
@@ -53,15 +64,18 @@ def normalize_dimensions(img_pil):
     return img_pil
 
 
-@pytest.mark.parametrize('img_path', [
-    pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
-    for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
-])
-@pytest.mark.parametrize('pil_mode, mode', [
-    (None, ImageReadMode.UNCHANGED),
-    ("L", ImageReadMode.GRAY),
-    ("RGB", ImageReadMode.RGB),
-])
+@pytest.mark.parametrize(
+    "img_path",
+    [pytest.param(jpeg_path,
id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize( + "pil_mode, mode", + [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("RGB", ImageReadMode.RGB), + ], +) def test_decode_jpeg(img_path, pil_mode, mode): with Image.open(img_path) as img: @@ -99,18 +113,21 @@ def test_decode_jpeg_errors(): def test_decode_bad_huffman_images(): # sanity check: make sure we can decode the bad Huffman encoding - bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) + bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")) decode_jpeg(bad_huff) -@pytest.mark.parametrize('img_path', [ - pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) - for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) -]) +@pytest.mark.parametrize( + "img_path", + [ + pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) + for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg")) + ], +) def test_damaged_corrupt_images(img_path): # Truncated images should raise an exception data = read_file(img_path) - if 'corrupt34' in img_path: + if "corrupt34" in img_path: match_message = "Image is incomplete or truncated" else: match_message = "Unsupported marker type" @@ -118,17 +135,20 @@ def test_damaged_corrupt_images(img_path): decode_jpeg(data) -@pytest.mark.parametrize('img_path', [ - pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(FAKEDATA_DIR, ".png") -]) -@pytest.mark.parametrize('pil_mode, mode', [ - (None, ImageReadMode.UNCHANGED), - ("L", ImageReadMode.GRAY), - ("LA", ImageReadMode.GRAY_ALPHA), - ("RGB", ImageReadMode.RGB), - ("RGBA", ImageReadMode.RGB_ALPHA), -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")], +) +@pytest.mark.parametrize( + "pil_mode, mode", + [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("LA", ImageReadMode.GRAY_ALPHA), + ("RGB", ImageReadMode.RGB), + ("RGBA", ImageReadMode.RGB_ALPHA), + ], +) def test_decode_png(img_path, pil_mode, mode): with Image.open(img_path) as img: @@ -159,10 +179,10 @@ def test_decode_png_errors(): decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) -@pytest.mark.parametrize('img_path', [ - pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(IMAGE_DIR, ".png") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], +) def test_encode_png(img_path): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) @@ -181,21 +201,19 @@ def test_encode_png_errors(): encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=-1) + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1) with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=10) + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10) with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) 
-@pytest.mark.parametrize('img_path', [ - pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(IMAGE_DIR, ".png") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], +) def test_write_png(img_path): with get_tmp_dir() as d: pil_image = Image.open(img_path) @@ -203,7 +221,7 @@ def test_write_png(img_path): img_pil = img_pil.permute(2, 0, 1) filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) + torch_png = os.path.join(d, "{0}_torch.png".format(filename)) write_png(img_pil, torch_png, compression_level=6) saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = saved_image.permute(2, 0, 1) @@ -213,9 +231,9 @@ def test_write_png(img_path): def test_read_file(): with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' + fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: + with open(fpath, "wb") as f: f.write(content) data = read_file(fpath) @@ -224,14 +242,14 @@ def test_read_file(): assert_equal(data, expected) with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): - read_file('tst') + read_file("tst") def test_read_file_non_ascii(): with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fname, content = "日本語(Japanese).bin", b"TorchVision\211\n" fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: + with open(fpath, "wb") as f: f.write(content) data = read_file(fpath) @@ -242,12 +260,12 @@ def test_read_file_non_ascii(): def test_write_file(): with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' + fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(d, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: saved_content = f.read() os.unlink(fpath) assert content == saved_content @@ -255,25 +273,28 @@ def test_write_file(): def test_write_file_non_ascii(): with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fname, content = "日本語(Japanese).bin", b"TorchVision\211\n" fpath = os.path.join(d, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: saved_content = f.read() os.unlink(fpath) assert content == saved_content -@pytest.mark.parametrize('shape', [ - (27, 27), - (60, 60), - (105, 105), -]) +@pytest.mark.parametrize( + "shape", + [ + (27, 27), + (60, 60), + (105, 105), + ], +) def test_read_1_bit_png(shape): with get_tmp_dir() as root: - image_path = os.path.join(root, f'test_{shape}.png') + image_path = os.path.join(root, f"test_{shape}.png") pixels = np.random.rand(*shape) > 0.5 img = Image.fromarray(pixels) img.save(image_path) @@ -282,18 +303,24 @@ def test_read_1_bit_png(shape): assert_equal(img1, img2) -@pytest.mark.parametrize('shape', [ - (27, 27), - (60, 60), - (105, 105), -]) -@pytest.mark.parametrize('mode', [ - ImageReadMode.UNCHANGED, - ImageReadMode.GRAY, -]) +@pytest.mark.parametrize( + "shape", + [ + (27, 27), + (60, 60), + (105, 105), + ], +) +@pytest.mark.parametrize( + "mode", + [ + ImageReadMode.UNCHANGED, + ImageReadMode.GRAY, + ], +) def test_read_1_bit_png_consistency(shape, mode): with get_tmp_dir() as 
root: - image_path = os.path.join(root, f'test_{shape}.png') + image_path = os.path.join(root, f"test_{shape}.png") pixels = np.random.rand(*shape) > 0.5 img = Image.fromarray(pixels) img.save(image_path) @@ -303,27 +330,27 @@ def test_read_1_bit_png_consistency(shape, mode): @needs_cuda -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(IMAGE_ROOT, ".jpg") -]) -@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize('scripted', (False, True)) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("scripted", (False, True)) def test_decode_jpeg_cuda(mode, img_path, scripted): - if 'cmyk' in img_path: + if "cmyk" in img_path: pytest.xfail("Decoding a CMYK jpeg isn't supported") data = read_file(img_path) img = decode_image(data, mode=mode) f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg - img_nvjpeg = f(data, mode=mode, device='cuda') + img_nvjpeg = f(data, mode=mode, device="cuda") # Some difference expected between jpeg implementations assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 @needs_cuda -@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) +@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) def test_decode_jpeg_cuda_device_param(cuda_device): """Make sure we can pass a string or a torch.device as device param""" data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) @@ -334,13 +361,13 @@ def test_decode_jpeg_cuda_device_param(cuda_device): def test_decode_jpeg_cuda_errors(): data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_jpeg(data.reshape(-1, 1), device='cuda') + decode_jpeg(data.reshape(-1, 1), device="cuda") with pytest.raises(RuntimeError, match="input tensor must be on CPU"): - decode_jpeg(data.to('cuda'), device='cuda') + decode_jpeg(data.to("cuda"), device="cuda") with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): - decode_jpeg(data.to(torch.float), device='cuda') + decode_jpeg(data.to(torch.float), device="cuda") with pytest.raises(RuntimeError, match="Expected a cuda device"): - torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') + torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu") def test_encode_jpeg_errors(): @@ -348,12 +375,10 @@ def test_encode_jpeg_errors(): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) - with pytest.raises(ValueError, match="Image quality should be a positive number " - "between 1 and 100"): + with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) - with pytest.raises(ValueError, match="Image quality should be a positive number " - "between 1 and 100"): + with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, 
got: 5"): @@ -374,14 +399,15 @@ def _inner(test_func): return test_func else: return pytest.mark.dont_collect(test_func) + return _inner @_collect_if(cond=IS_WINDOWS) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_encode_jpeg_reference(img_path): # This test is *wrong*. # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it @@ -395,12 +421,11 @@ def test_encode_jpeg_reference(img_path): # FIXME: make the correct tests pass on windows and remove this. dirname = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) - write_folder = os.path.join(dirname, 'jpeg_write') - expected_file = os.path.join( - write_folder, '{0}_pil.jpg'.format(filename)) + write_folder = os.path.join(dirname, "jpeg_write") + expected_file = os.path.join(write_folder, "{0}_pil.jpg".format(filename)) img = decode_jpeg(read_file(img_path)) - with open(expected_file, 'rb') as f: + with open(expected_file, "rb") as f: pil_bytes = f.read() pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) for src_img in [img, img.contiguous()]: @@ -410,10 +435,10 @@ def test_encode_jpeg_reference(img_path): @_collect_if(cond=IS_WINDOWS) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_write_jpeg_reference(img_path): # FIXME: Remove this eventually, see test_encode_jpeg_reference with get_tmp_dir() as d: @@ -422,35 +447,31 @@ def test_write_jpeg_reference(img_path): basedir = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - d, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + torch_jpeg = os.path.join(d, "{0}_torch.jpg".format(filename)) + pil_jpeg = os.path.join(basedir, "jpeg_write", "{0}_pil.jpg".format(filename)) write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: + with open(torch_jpeg, "rb") as f: torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: + with open(pil_jpeg, "rb") as f: pil_bytes = f.read() assert_equal(torch_bytes, pil_bytes) -@pytest.mark.skipif(IS_WINDOWS, reason=( - 'this test fails on windows because PIL uses libjpeg-turbo on windows' -)) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows")) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_encode_jpeg(img_path): img = read_image(img_path) pil_img = F.to_pil_image(img) buf = io.BytesIO() - pil_img.save(buf, format='JPEG', quality=75) + pil_img.save(buf, format="JPEG", quality=75) # pytorch can't read from raw bytes so we go through numpy pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) @@ -461,29 +482,27 @@ def test_encode_jpeg(img_path): 
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) -@pytest.mark.skipif(IS_WINDOWS, reason=( - 'this test fails on windows because PIL uses libjpeg-turbo on windows' -)) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows")) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_write_jpeg(img_path): with get_tmp_dir() as d: d = Path(d) img = read_image(img_path) pil_img = F.to_pil_image(img) - torch_jpeg = str(d / 'torch.jpg') - pil_jpeg = str(d / 'pil.jpg') + torch_jpeg = str(d / "torch.jpg") + pil_jpeg = str(d / "pil.jpg") write_jpeg(img, torch_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: + with open(torch_jpeg, "rb") as f: torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: + with open(pil_jpeg, "rb") as f: pil_bytes = f.read() assert_equal(torch_bytes, pil_bytes) diff --git a/test/test_internet.py b/test/test_internet.py index 772379a2289..c7e79345bb6 100644 --- a/test/test_internet.py +++ b/test/test_internet.py @@ -6,16 +6,16 @@ """ import os -import pytest -import warnings from urllib.error import URLError +import pytest + import torchvision.datasets.utils as utils + from common_utils import get_tmp_dir class TestDatasetUtils: - def test_get_redirect_url(self): url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" @@ -59,12 +59,12 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker): filename = "filename" md5 = "md5" - mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive') + mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive") with get_tmp_dir() as root: utils.download_url(url, root, filename, md5) mocked.assert_called_once_with(id, root, filename, md5) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_io.py b/test/test_io.py index 56cd0af5fd8..245bf8e6cc5 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,19 +1,19 @@ -import pytest -import os import contextlib +import os import sys import tempfile + +import pytest + import torch import torchvision.io as io from torchvision import get_video_backend -import warnings -from urllib.error import URLError - -from common_utils import get_tmp_dir, assert_equal +from common_utils import assert_equal, get_tmp_dir try: import av + # Do a version test too io.video._check_av_available() except ImportError: @@ -42,29 +42,30 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, raise ValueError("video_codec can't be specified together with lossless") if options is not None: raise ValueError("options can't be specified together with lossless") - video_codec = 'libx264rgb' - options = {'crf': '0'} + video_codec = "libx264rgb" + options = {"crf": "0"} if video_codec is None: if get_video_backend() == "pyav": - video_codec = 'libx264' + video_codec = "libx264" else: # when video_codec is not set, we assume it is libx264rgb which accepts # RGB pixel formats as input instead of YUV - video_codec = 'libx264rgb' + video_codec = "libx264rgb" if options is None: options = {} data = 
_create_video_frames(num_frames, height, width) - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: f.close() io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) yield f.name, data os.unlink(f.name) -@pytest.mark.skipif(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, - reason="video_reader backend not available") +@pytest.mark.skipif( + get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, reason="video_reader backend not available" +) @pytest.mark.skipif(av is None, reason="PyAV unavailable") class TestVideo: # compression adds artifacts, thus we add a tolerance of @@ -107,14 +108,14 @@ def test_read_timestamps(self): assert pts == expected_pts - @pytest.mark.parametrize('start', range(5)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(5)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video(self, start, offset): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start:(start + offset)] + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv) @@ -125,22 +126,22 @@ def test_read_partial_video(self, start, offset): assert len(lv) == 4 assert_equal(data[4:8], lv) - @pytest.mark.parametrize('start', range(0, 80, 20)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(0, 80, 20)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video_bframes(self, start, offset): # do not use lossless encoding, to test the presence of B-frames - options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'} + options = {"bframes": "16", "keyint": "10", "min-keyint": "4"} with temp_video(100, 300, 300, 5, options=options) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start:(start + offset)] + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE) lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) # TODO fix this - if get_video_backend() == 'pyav': + if get_video_backend() == "pyav": assert len(lv) == 4 assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE) else: @@ -156,7 +157,7 @@ def test_read_packed_b_frames_divx_file(self): assert fps == 30 def test_read_timestamps_from_packet(self): - with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data): + with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data): pts, _ = io.read_video_timestamps(f_name) # note: not all formats/codecs provide accurate information for computing the # timestamps. 
For the format that we use here, this information is available, @@ -164,7 +165,7 @@ def test_read_timestamps_from_packet(self): with av.open(f_name) as container: stream = container.streams[0] # make sure we went through the optimized codepath - assert b'Lavc' in stream.codec_context.extradata + assert b"Lavc" in stream.codec_context.extradata pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) expected_pts = [i * pts_step for i in range(num_frames)] @@ -173,7 +174,7 @@ def test_read_timestamps_from_packet(self): def test_read_video_pts_unit_sec(self): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name, pts_unit='sec') + lv, _, info = io.read_video(f_name, pts_unit="sec") assert_equal(data, lv) assert info["video_fps"] == 5 @@ -181,7 +182,7 @@ def test_read_video_pts_unit_sec(self): def test_read_timestamps_pts_unit_sec(self): with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") with av.open(f_name) as container: stream = container.streams[0] @@ -191,22 +192,22 @@ def test_read_timestamps_pts_unit_sec(self): assert pts == expected_pts - @pytest.mark.parametrize('start', range(5)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(5)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video_pts_unit_sec(self, start, offset): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec') - s_data = data[start:(start + offset)] + lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec") + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv) with av.open(f_name) as container: stream = container.streams[0] - lv, _, _ = io.read_video(f_name, - int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], - pts_unit='sec') + lv, _, _ = io.read_video( + f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec" + ) if get_video_backend() == "pyav": # for "video_reader" backend, we don't decode the closest early frame # when the given start pts is not matching any frame pts @@ -214,8 +215,8 @@ def test_read_partial_video_pts_unit_sec(self, start, offset): assert_equal(data[4:8], lv) def test_read_video_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: - f.write(b'This is not an mpg4 file') + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"This is not an mpg4 file") video, audio, info = io.read_video(f.name) assert isinstance(video, torch.Tensor) assert isinstance(audio, torch.Tensor) @@ -224,8 +225,8 @@ def test_read_video_corrupted_file(self): assert info == {} def test_read_video_timestamps_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: - f.write(b'This is not an mpg4 file') + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"This is not an mpg4 file") video_pts, video_fps = io.read_video_timestamps(f.name) assert video_pts == [] assert video_fps is None @@ -233,18 +234,18 @@ def test_read_video_timestamps_corrupted_file(self): 
@pytest.mark.skip(reason="Temporarily disabled due to new pyav") def test_read_video_partially_corrupted_file(self): with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): - with open(f_name, 'r+b') as f: + with open(f_name, "r+b") as f: size = os.path.getsize(f_name) bytes_to_overwrite = size // 10 # seek to the middle of the file f.seek(5 * bytes_to_overwrite) # corrupt 10% of the file from the middle - f.write(b'\xff' * bytes_to_overwrite) + f.write(b"\xff" * bytes_to_overwrite) # this exercises the container.decode assertion check - video, audio, info = io.read_video(f.name, pts_unit='sec') + video, audio, info = io.read_video(f.name, pts_unit="sec") # check that size is not equal to 5, but 3 # TODO fix this - if get_video_backend() == 'pyav': + if get_video_backend() == "pyav": assert len(video) == 3 else: assert len(video) == 4 @@ -254,7 +255,7 @@ def test_read_video_partially_corrupted_file(self): with pytest.raises(AssertionError): assert_equal(video, data) - @pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows') + @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") def test_write_video_with_audio(self): f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") @@ -266,15 +267,13 @@ def test_write_video_with_audio(self): video_tensor, round(info["video_fps"]), video_codec="libx264rgb", - options={'crf': '0'}, + options={"crf": "0"}, audio_array=audio_tensor, audio_fps=info["audio_fps"], audio_codec="aac", ) - out_video_tensor, out_audio_tensor, out_info = io.read_video( - out_f_name, pts_unit="sec" - ) + out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec") assert info["video_fps"] == out_info["video_fps"] assert_equal(video_tensor, out_video_tensor) @@ -290,5 +289,5 @@ def test_write_video_with_audio(self): # TODO add tests for audio -if __name__ == '__main__': +if __name__ == "__main__": pytest.main(__file__) diff --git a/test/test_io_opt.py b/test/test_io_opt.py index 87698b34624..c6273b94de8 100644 --- a/test/test_io_opt.py +++ b/test/test_io_opt.py @@ -1,12 +1,13 @@ import unittest -from torchvision import set_video_backend + import test_io +from torchvision import set_video_backend # noqa: F401 # Disabling the video backend switching temporarily # set_video_backend('video_reader') -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(test_io) unittest.TextTestRunner(verbosity=1).run(suite) diff --git a/test/test_models.py b/test/test_models.py index 72ae68f5615..520b66fd885 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,22 +1,23 @@ -import os -import io -import sys -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda -from _utils_internal import get_relative_path -from collections import OrderedDict import functools +import io import operator +import os +import traceback +import warnings +from collections import OrderedDict + +import pytest + import torch import torch.fx import torch.nn as nn import torchvision from torchvision import models -import pytest -import warnings -import traceback +from _utils_internal import get_relative_path +from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed -ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1' +ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" def 
get_available_classification_models(): @@ -50,7 +51,7 @@ def _get_expected_file(name=None): # Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names # We hardcode it here to avoid having to re-generate the reference files - expected_file = expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name) + expected_file = expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name) expected_file += "_expect.pkl" if not ACCEPT and not os.path.exists(expected_file): @@ -92,6 +93,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): def assert_export_import_module(m, args): """Check that the results of a model are the same after saving and loading""" + def get_export_import_copy(m): """Save and load a TorchScript model""" buffer = io.BytesIO() @@ -115,15 +117,17 @@ def get_export_import_copy(m): if a is not None: torch.testing.assert_close(a, b, atol=tol, rtol=tol) - TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' + TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1" if not TEST_WITH_SLOW or skip: # TorchScript is not enabled, skip these tests - msg = "The check_jit_scriptable test for {} was skipped. " \ - "This test checks if the module's results in TorchScript " \ - "match eager and that it can be exported. To run these " \ - "tests make sure you set the environment variable " \ - "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \ - "manually skipped.".format(nn_module.__class__.__name__) + msg = ( + "The check_jit_scriptable test for {} was skipped. " + "This test checks if the module's results in TorchScript " + "match eager and that it can be exported. To run these " + "tests make sure you set the environment variable " + "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " + "manually skipped.".format(nn_module.__class__.__name__) + ) warnings.warn(msg, RuntimeWarning) return None @@ -152,8 +156,8 @@ def _check_fx_compatible(model, inputs): # before they are compared to the eager model outputs. This is useful if the # model outputs are different between TorchScript / Eager mode script_model_unwrapper = { - 'googlenet': lambda x: x.logits, - 'inception_v3': lambda x: x.logits, + "googlenet": lambda x: x.logits, + "inception_v3": lambda x: x.logits, "fasterrcnn_resnet50_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], @@ -192,43 +196,41 @@ def _check_fx_compatible(model, inputs): # The following contains configuration parameters for all models which are used by # the _test_*_model methods. 
_model_params = { - 'inception_v3': { - 'input_shape': (1, 3, 299, 299) + "inception_v3": {"input_shape": (1, 3, 299, 299)}, + "retinanet_resnet50_fpn": { + "num_classes": 20, + "score_thresh": 0.01, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'retinanet_resnet50_fpn': { - 'num_classes': 20, - 'score_thresh': 0.01, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "keypointrcnn_resnet50_fpn": { + "num_classes": 2, + "min_size": 224, + "max_size": 224, + "box_score_thresh": 0.15, + "input_shape": (3, 224, 224), }, - 'keypointrcnn_resnet50_fpn': { - 'num_classes': 2, - 'min_size': 224, - 'max_size': 224, - 'box_score_thresh': 0.15, - 'input_shape': (3, 224, 224), + "fasterrcnn_resnet50_fpn": { + "num_classes": 20, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'fasterrcnn_resnet50_fpn': { - 'num_classes': 20, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "maskrcnn_resnet50_fpn": { + "num_classes": 10, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'maskrcnn_resnet50_fpn': { - 'num_classes': 10, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "fasterrcnn_mobilenet_v3_large_fpn": { + "box_score_thresh": 0.02076, }, - 'fasterrcnn_mobilenet_v3_large_fpn': { - 'box_score_thresh': 0.02076, + "fasterrcnn_mobilenet_v3_large_320_fpn": { + "box_score_thresh": 0.02076, + "rpn_pre_nms_top_n_test": 1000, + "rpn_post_nms_top_n_test": 1000, }, - 'fasterrcnn_mobilenet_v3_large_320_fpn': { - 'box_score_thresh': 0.02076, - 'rpn_pre_nms_top_n_test': 1000, - 'rpn_post_nms_top_n_test': 1000, - } } @@ -242,7 +244,7 @@ def _make_sliced_model(model, stop_layer): return new_model -@pytest.mark.parametrize('model_name', ['densenet121', 'densenet169', 'densenet201', 'densenet161']) +@pytest.mark.parametrize("model_name", ["densenet121", "densenet169", "densenet201", "densenet161"]) def test_memory_efficient_densenet(model_name): input_shape = (1, 3, 300, 300) x = torch.rand(input_shape) @@ -264,9 +266,9 @@ def test_memory_efficient_densenet(model_name): torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5) -@pytest.mark.parametrize('dilate_layer_2', (True, False)) -@pytest.mark.parametrize('dilate_layer_3', (True, False)) -@pytest.mark.parametrize('dilate_layer_4', (True, False)) +@pytest.mark.parametrize("dilate_layer_2", (True, False)) +@pytest.mark.parametrize("dilate_layer_3", (True, False)) +@pytest.mark.parametrize("dilate_layer_4", (True, False)) def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4): # TODO improve tests to also check that each layer has the right dimensionality model = models.__dict__["resnet50"](replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4)) @@ -286,7 +288,7 @@ def test_mobilenet_v2_residual_setting(): assert out.shape[-1] == 1000 -@pytest.mark.parametrize('model_name', ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]) +@pytest.mark.parametrize("model_name", ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]) def test_mobilenet_norm_layer(model_name): model = models.__dict__[model_name]() assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) @@ -295,16 +297,16 @@ def get_gn(num_channels): return nn.GroupNorm(32, num_channels) model = models.__dict__[model_name](norm_layer=get_gn) - assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) assert 
any(isinstance(x, nn.GroupNorm) for x in model.modules()) def test_inception_v3_eval(): # replacement for models.inception_v3(pretrained=True) that does not download weights kwargs = {} - kwargs['transform_input'] = True - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + kwargs["transform_input"] = True + kwargs["aux_logits"] = True + kwargs["init_weights"] = False name = "inception_v3" model = models.Inception3(**kwargs) model.aux_logits = False @@ -332,9 +334,9 @@ def test_fasterrcnn_double(): def test_googlenet_eval(): # replacement for models.googlenet(pretrained=True) that does not download weights kwargs = {} - kwargs['transform_input'] = True - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + kwargs["transform_input"] = True + kwargs["aux_logits"] = True + kwargs["init_weights"] = False name = "googlenet" model = models.GoogLeNet(**kwargs) model.aux_logits = False @@ -357,7 +359,7 @@ def checkOut(out): model.cuda() model.eval() input_shape = (3, 300, 300) - x = torch.rand(input_shape, device='cuda') + x = torch.rand(input_shape, device="cuda") model_input = [x] out = model(model_input) assert model_input[0] is x @@ -383,30 +385,29 @@ def test_generalizedrcnn_transform_repr(): image_mean = [0.485, 0.456, 0.406] image_std = [0.229, 0.224, 0.225] - t = models.detection.transform.GeneralizedRCNNTransform(min_size=min_size, - max_size=max_size, - image_mean=image_mean, - image_std=image_std) + t = models.detection.transform.GeneralizedRCNNTransform( + min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std + ) # Check integrity of object __repr__ attribute - expected_string = 'GeneralizedRCNNTransform(' - _indent = '\n ' - expected_string += '{0}Normalize(mean={1}, std={2})'.format(_indent, image_mean, image_std) - expected_string += '{0}Resize(min_size=({1},), max_size={2}, '.format(_indent, min_size, max_size) + expected_string = "GeneralizedRCNNTransform(" + _indent = "\n " + expected_string += "{0}Normalize(mean={1}, std={2})".format(_indent, image_mean, image_std) + expected_string += "{0}Resize(min_size=({1},), max_size={2}, ".format(_indent, min_size, max_size) expected_string += "mode='bilinear')\n)" assert t.__repr__() == expected_string -@pytest.mark.parametrize('model_name', get_available_classification_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_classification_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_classification_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 50, - 'input_shape': (1, 3, 224, 224), + "num_classes": 50, + "input_shape": (1, 3, 224, 224), } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -427,17 +428,17 @@ def test_classification_model(model_name, dev): assert out.shape[-1] == 50 -@pytest.mark.parametrize('model_name', get_available_segmentation_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_segmentation_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_segmentation_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 10, - 'pretrained_backbone': False, - 'input_shape': (1, 3, 32, 32), + "num_classes": 10, + "pretrained_backbone": False, + "input_shape": (1, 3, 32, 32), } kwargs = {**defaults, **_model_params.get(model_name, 
{})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.segmentation.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -476,25 +477,27 @@ def check_out(out): full_validation &= check_out(out) if not full_validation: - msg = "The output of {} could only be partially validated. " \ - "This is likely due to unit-test flakiness, but you may " \ - "want to do additional manual checks if you made " \ - "significant changes to the codebase.".format(test_segmentation_model.__name__) + msg = ( + "The output of {} could only be partially validated. " + "This is likely due to unit-test flakiness, but you may " + "want to do additional manual checks if you made " + "significant changes to the codebase.".format(test_segmentation_model.__name__) + ) warnings.warn(msg, RuntimeWarning) pytest.skip(msg) -@pytest.mark.parametrize('model_name', get_available_detection_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_detection_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_detection_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 50, - 'pretrained_backbone': False, - 'input_shape': (3, 300, 300), + "num_classes": 50, + "pretrained_backbone": False, + "input_shape": (3, 300, 300), } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.detection.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -522,7 +525,7 @@ def subsample_tensor(tensor): return tensor ith_index = num_elems // num_samples - return tensor[ith_index - 1::ith_index] + return tensor[ith_index - 1 :: ith_index] def compute_mean_std(tensor): # can't compute mean of integral tensor @@ -545,8 +548,9 @@ def compute_mean_std(tensor): # scores. expected_file = _get_expected_file(model_name) expected = torch.load(expected_file) - torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, - check_device=False, check_dtype=False) + torch.testing.assert_close( + output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False + ) # Note: Fmassa proposed turning off NMS by adapting the threshold # and then using the Hungarian algorithm as in DETR to find the @@ -567,15 +571,17 @@ def compute_mean_std(tensor): full_validation &= check_out(out) if not full_validation: - msg = "The output of {} could only be partially validated. " \ - "This is likely due to unit-test flakiness, but you may " \ - "want to do additional manual checks if you made " \ - "significant changes to the codebase.".format(test_detection_model.__name__) + msg = ( + "The output of {} could only be partially validated. 
" + "This is likely due to unit-test flakiness, but you may " + "want to do additional manual checks if you made " + "significant changes to the codebase.".format(test_detection_model.__name__) + ) warnings.warn(msg, RuntimeWarning) pytest.skip(msg) -@pytest.mark.parametrize('model_name', get_available_detection_models()) +@pytest.mark.parametrize("model_name", get_available_detection_models()) def test_detection_model_validation(model_name): set_rng_seed(0) model = models.detection.__dict__[model_name](num_classes=50, pretrained_backbone=False) @@ -587,25 +593,25 @@ def test_detection_model_validation(model_name): model(x) # validate type - targets = [{'boxes': 0.}] + targets = [{"boxes": 0.0}] with pytest.raises(ValueError): model(x, targets=targets) # validate boxes shape for boxes in (torch.rand((4,)), torch.rand((1, 5))): - targets = [{'boxes': boxes}] + targets = [{"boxes": boxes}] with pytest.raises(ValueError): model(x, targets=targets) # validate that no degenerate boxes are present boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]]) - targets = [{'boxes': boxes}] + targets = [{"boxes": boxes}] with pytest.raises(ValueError): model(x, targets=targets) -@pytest.mark.parametrize('model_name', get_available_video_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_video_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_video_model(model_name, dev): # the default input shape is # bs * num_channels * clip_len * h *w @@ -626,25 +632,29 @@ def test_video_model(model_name, dev): assert out.shape[-1] == 50 -@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and - 'qnnpack' in torch.backends.quantized.supported_engines), - reason="This Pytorch Build has not been built with fbgemm and qnnpack") -@pytest.mark.parametrize('model_name', get_available_quantizable_models()) +@pytest.mark.skipif( + not ( + "fbgemm" in torch.backends.quantized.supported_engines + and "qnnpack" in torch.backends.quantized.supported_engines + ), + reason="This Pytorch Build has not been built with fbgemm and qnnpack", +) +@pytest.mark.parametrize("model_name", get_available_quantizable_models()) def test_quantized_classification_model(model_name): defaults = { - 'input_shape': (1, 3, 224, 224), - 'pretrained': False, - 'quantize': True, + "input_shape": (1, 3, 224, 224), + "pretrained": False, + "quantize": True, } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") # First check if quantize=True provides models that can run with input data model = torchvision.models.quantization.__dict__[model_name](**kwargs) x = torch.rand(input_shape) model(x) - kwargs['quantize'] = False + kwargs["quantize"] = False for eval_mode in [True, False]: model = torchvision.models.quantization.__dict__[model_name](**kwargs) if eval_mode: @@ -670,5 +680,5 @@ def test_quantized_classification_model(model_name): raise AssertionError(f"model cannot be scripted. 
Traceback = {str(tb)}") from e -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 4477e9e1566..eacd7a831c2 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -1,13 +1,18 @@ +import pytest + import torch -from common_utils import TestCase, assert_equal from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator from torchvision.models.detection.image_list import ImageList -import pytest + +from common_utils import TestCase, assert_equal class Tester(TestCase): def test_incorrect_anchors(self): - incorrect_sizes = ((2, 4, 8), (32, 8), ) + incorrect_sizes = ( + (2, 4, 8), + (32, 8), + ) incorrect_aspects = (0.5, 1.0) anc = AnchorGenerator(incorrect_sizes, incorrect_aspects) image1 = torch.randn(3, 800, 800) @@ -49,15 +54,19 @@ def test_anchor_generator(self): for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()): num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc - anchors_output = torch.tensor([[-5., -5., 5., 5.], - [0., -5., 10., 5.], - [5., -5., 15., 5.], - [-5., 0., 5., 10.], - [0., 0., 10., 10.], - [5., 0., 15., 10.], - [-5., 5., 5., 15.], - [0., 5., 10., 15.], - [5., 5., 15., 15.]]) + anchors_output = torch.tensor( + [ + [-5.0, -5.0, 5.0, 5.0], + [0.0, -5.0, 10.0, 5.0], + [5.0, -5.0, 15.0, 5.0], + [-5.0, 0.0, 5.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + [5.0, 0.0, 15.0, 10.0], + [-5.0, 5.0, 5.0, 15.0], + [0.0, 5.0, 10.0, 15.0], + [5.0, 5.0, 15.0, 15.0], + ] + ) assert num_anchors_estimated == 9 assert len(anchors) == 2 @@ -76,12 +85,14 @@ def test_defaultbox_generator(self): model.eval() dboxes = model(images, features) - dboxes_output = torch.tensor([ - [6.3750, 6.3750, 8.6250, 8.6250], - [4.7443, 4.7443, 10.2557, 10.2557], - [5.9090, 6.7045, 9.0910, 8.2955], - [6.7045, 5.9090, 8.2955, 9.0910] - ]) + dboxes_output = torch.tensor( + [ + [6.3750, 6.3750, 8.6250, 8.6250], + [4.7443, 4.7443, 10.2557, 10.2557], + [5.9090, 6.7045, 9.0910, 8.2955], + [6.7045, 5.9090, 8.2955, 9.0910], + ] + ) assert len(dboxes) == 2 assert tuple(dboxes[0].shape) == (4, 4) diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index a4b7064b338..8a4c76d9741 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -1,25 +1,26 @@ -import torch +import pytest +import torch import torchvision.models -from torchvision.ops import MultiScaleRoIAlign -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork -from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from torchvision.models.detection.roi_heads import RoIHeads +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead +from torchvision.ops import MultiScaleRoIAlign -import pytest from common_utils import assert_equal class TestModelsDetectionNegativeSamples: - def _make_empty_sample(self, add_masks=False, add_keypoints=False): images = [torch.rand((3, 100, 100), dtype=torch.float32)] boxes = torch.zeros((0, 4), dtype=torch.float32) - negative_target = {"boxes": boxes, - "labels": torch.zeros(0, dtype=torch.int64), - "image_id": 4, - "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), - "iscrowd": torch.zeros((0,), 
dtype=torch.int64)} + negative_target = { + "boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64), + } if add_masks: negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8) @@ -36,16 +37,10 @@ def test_targets_to_anchors(self): anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = AnchorGenerator( - anchor_sizes, aspect_ratios - ) + rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0]) - head = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - 0.5, 0.3, - 256, 0.5, - 2000, 2000, 0.7, 0.05) + head = RegionProposalNetwork(rpn_anchor_generator, rpn_head, 0.5, 0.3, 256, 0.5, 2000, 2000, 0.7, 0.05) labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets) @@ -63,29 +58,29 @@ def test_assign_targets_to_proposals(self): gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)] gt_labels = [torch.tensor([[0]], dtype=torch.int64)] - box_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) + box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - 4 * resolution ** 2, - representation_size) + box_head = TwoMLPHead(4 * resolution ** 2, representation_size) representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - 2) + box_predictor = FastRCNNPredictor(representation_size, 2) roi_heads = RoIHeads( # Box - box_roi_pool, box_head, box_predictor, - 0.5, 0.5, - 512, 0.25, + box_roi_pool, + box_head, + box_predictor, + 0.5, + 0.5, + 512, + 0.25, None, - 0.05, 0.5, 100) + 0.05, + 0.5, + 100, + ) matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) @@ -97,61 +92,61 @@ def test_assign_targets_to_proposals(self): assert labels[0].shape == torch.Size([proposals[0].shape[0]]) assert labels[0].dtype == torch.int64 - @pytest.mark.parametrize('name', [ - "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", - ]) + @pytest.mark.parametrize( + "name", + [ + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + ], + ) def test_forward_negative_sample_frcnn(self, name): - model = torchvision.models.detection.__dict__[name]( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.__dict__[name](num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) def test_forward_negative_sample_mrcnn(self): - model = torchvision.models.detection.maskrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_masks=True) loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - 
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_mask"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_mask"], torch.tensor(0.0)) def test_forward_negative_sample_krcnn(self): - model = torchvision.models.detection.keypointrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.keypointrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_keypoints=True) loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0)) def test_forward_negative_sample_retinanet(self): model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) + num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + ) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) + assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) def test_forward_negative_sample_ssd(self): - model = torchvision.models.detection.ssd300_vgg16( - num_classes=2, pretrained_backbone=False) + model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) + assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index b599bbeaea1..29fdf7b860a 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -1,14 +1,15 @@ import copy + +import pytest + import torch -from torchvision.models.detection import _utils +from torchvision.models.detection import _utils, backbone_utils from torchvision.models.detection.transform import GeneralizedRCNNTransform -import pytest -from torchvision.models.detection import backbone_utils + from common_utils import assert_equal class TestModelsDetectionUtils: - def test_balanced_positive_negative_sampler(self): sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25) # keep all 6 negatives first, then add 3 positives, last two are ignore @@ -22,16 +23,13 @@ def test_balanced_positive_negative_sampler(self): assert neg[0].sum() == 3 assert neg[0][0:6].sum() == 3 - @pytest.mark.parametrize('train_layers, exp_froz_params', [ - (0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0) - ]) + @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)]) def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # we know how many initial layers and parameters of the network should # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if trainable_backbone_layers=2 - model = 
backbone_utils.resnet_fpn_backbone( - 'resnet50', pretrained=False, trainable_layers=train_layers) + model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers) # boolean list that is true if the param at that index is frozen is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] # check that expected initial number of layers are frozen @@ -40,34 +38,37 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) + pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3 + ) assert ret == 3 # can't go beyond 5 with pytest.raises(AssertionError): ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) + pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3 + ) # if not pretrained, should use all trainable layers and warn with pytest.warns(UserWarning): ret = backbone_utils._validate_trainable_layers( - pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) + pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3 + ) assert ret == 5 def test_transform_copy_targets(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] - targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}] + targets = [{"boxes": torch.rand(3, 4)}, {"boxes": torch.rand(2, 4)}] targets_copy = copy.deepcopy(targets) out = transform(image, targets) # noqa: F841 - assert_equal(targets[0]['boxes'], targets_copy[0]['boxes']) - assert_equal(targets[1]['boxes'], targets_copy[1]['boxes']) + assert_equal(targets[0]["boxes"], targets_copy[0]["boxes"]) + assert_equal(targets[1]["boxes"], targets_copy[1]["boxes"]) def test_not_float_normalize(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] - targets = [{'boxes': torch.rand(3, 4)}] + targets = [{"boxes": torch.rand(3, 4)}] with pytest.raises(TypeError): out = transform(image, targets) # noqa: F841 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_onnx.py b/test/test_onnx.py index c093ccb4863..66d592f45e8 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -1,4 +1,20 @@ # onnxruntime requires python 3.5 or above +import io +from collections import OrderedDict + +import pytest + +import torch +from torchvision import models, ops +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from torchvision.models.detection.image_list import ImageList +from torchvision.models.detection.roi_heads import RoIHeads +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead +from torchvision.models.detection.transform import GeneralizedRCNNTransform +from torchvision.ops._register_onnx_ops import _onnx_opset_version + +from common_utils import assert_equal, set_rng_seed + try: # This import should be before that of torch # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840 @@ -6,31 +22,31 @@ except ImportError: onnxruntime = None -from common_utils import set_rng_seed, assert_equal -import io 
-import torch -from torchvision import ops -from torchvision import models -from torchvision.models.detection.image_list import ImageList -from torchvision.models.detection.transform import GeneralizedRCNNTransform -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork -from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead -from collections import OrderedDict - -import pytest -from torchvision.ops._register_onnx_ops import _onnx_opset_version +try: + # This import should be before that of torch + # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840 + import onnxruntime +except ImportError: + onnxruntime = None -@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable') +@pytest.mark.skipif(onnxruntime is None, reason="ONNX Runtime unavailable") class TestONNXExporter: @classmethod def setup_class(cls): torch.manual_seed(123) - def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, - output_names=None, input_names=None): + def run_model( + self, + model, + inputs_list, + tolerate_small_mismatch=False, + do_constant_folding=True, + dynamic_axes=None, + output_names=None, + input_names=None, + ): model.eval() onnx_io = io.BytesIO() @@ -39,14 +55,20 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta else: torch_onnx_input = inputs_list[0] # export to onnx with the first input - torch.onnx.export(model, torch_onnx_input, onnx_io, - do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, - dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) + torch.onnx.export( + model, + torch_onnx_input, + onnx_io, + do_constant_folding=do_constant_folding, + opset_version=_onnx_opset_version, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + ) # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): - if isinstance(test_inputs, torch.Tensor) or \ - isinstance(test_inputs, list): + if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list): test_inputs = (test_inputs,) test_ouputs = model(*test_inputs) if isinstance(test_ouputs, torch.Tensor): @@ -117,9 +139,9 @@ class Module(torch.nn.Module): def forward(self, boxes, size): return ops.boxes.clip_boxes_to_image(boxes, size.shape) - self.run_model(Module(), [(boxes, size), (boxes, size_2)], - input_names=["boxes", "size"], - dynamic_axes={"size": [0, 1]}) + self.run_model( + Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]} + ) def test_roi_align(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) @@ -184,11 +206,11 @@ def forward(self_module, images): input = torch.rand(3, 10, 20) input_test = torch.rand(3, 100, 150) - self.run_model(TransformModule(), [(input,), (input_test,)], - input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}) + self.run_model( + TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]} + ) def test_transform_images(self): - class TransformModule(torch.nn.Module): def __init__(self_module): super(TransformModule, self_module).__init__() @@ -225,11 +247,17 @@ def _init_test_rpn(self): rpn_score_thresh = 0.0 rpn = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - 
rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, - score_thresh=rpn_score_thresh) + rpn_anchor_generator, + rpn_head, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_nms_thresh, + score_thresh=rpn_score_thresh, + ) return rpn def _init_test_roi_heads_faster_rcnn(self): @@ -245,38 +273,38 @@ def _init_test_roi_heads_faster_rcnn(self): box_nms_thresh = 0.5 box_detections_per_img = 100 - box_roi_pool = ops.MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) + box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - out_channels * resolution ** 2, - representation_size) + box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - num_classes) + box_predictor = FastRCNNPredictor(representation_size, num_classes) roi_heads = RoIHeads( - box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, + box_roi_pool, + box_head, + box_predictor, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + ) return roi_heads def get_features(self, images): s0, s1 = images.shape[-2:] features = [ - ('0', torch.rand(2, 256, s0 // 4, s1 // 4)), - ('1', torch.rand(2, 256, s0 // 8, s1 // 8)), - ('2', torch.rand(2, 256, s0 // 16, s1 // 16)), - ('3', torch.rand(2, 256, s0 // 32, s1 // 32)), - ('4', torch.rand(2, 256, s0 // 64, s1 // 64)), + ("0", torch.rand(2, 256, s0 // 4, s1 // 4)), + ("1", torch.rand(2, 256, s0 // 8, s1 // 8)), + ("2", torch.rand(2, 256, s0 // 16, s1 // 16)), + ("3", torch.rand(2, 256, s0 // 32, s1 // 32)), + ("4", torch.rand(2, 256, s0 // 64, s1 // 64)), ] features = OrderedDict(features) return features @@ -302,36 +330,56 @@ def forward(self_module, images, features): model.eval() model(images, features) - self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, - input_names=["input1", "input2", "input3", "input4", "input5", "input6"], - dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], - "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], - "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) + self.run_model( + model, + [(images, features), (images2, test_features)], + tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={ + "input1": [0, 1, 2, 3], + "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], + "input6": [0, 1, 2, 3], + }, + ) def test_multi_scale_roi_align(self): - class TransformModule(torch.nn.Module): def __init__(self): super(TransformModule, self).__init__() - self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2) + self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2) self.image_sizes = [(512, 512)] def forward(self, input, boxes): return self.model(input, boxes, self.image_sizes) i = OrderedDict() - i['feat1'] = torch.rand(1, 5, 64, 64) - i['feat2'] = torch.rand(1, 5, 16, 16) + i["feat1"] = torch.rand(1, 5, 64, 
64) + i["feat2"] = torch.rand(1, 5, 16, 16) boxes = torch.rand(6, 4) * 256 boxes[:, 2:] += boxes[:, :2] i1 = OrderedDict() - i1['feat1'] = torch.rand(1, 5, 64, 64) - i1['feat2'] = torch.rand(1, 5, 16, 16) + i1["feat1"] = torch.rand(1, 5, 64, 64) + i1["feat2"] = torch.rand(1, 5, 16, 16) boxes1 = torch.rand(6, 4) * 256 boxes1[:, 2:] += boxes1[:, :2] - self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)]) + self.run_model( + TransformModule(), + [ + ( + i, + [boxes], + ), + ( + i1, + [boxes1], + ), + ], + ) def test_roi_heads(self): class RoiHeadsModule(torch.nn.Module): @@ -346,9 +394,7 @@ def forward(self_module, images, features): images = ImageList(images, [i.shape[-2:] for i in images]) proposals, _ = self_module.rpn(images, features) detections, _ = self_module.roi_heads(features, proposals, images.image_sizes) - detections = self_module.transform.postprocess(detections, - images.image_sizes, - original_image_sizes) + detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes) return detections images = torch.rand(2, 3, 100, 100) @@ -360,15 +406,27 @@ def forward(self_module, images, features): model.eval() model(images, features) - self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, - input_names=["input1", "input2", "input3", "input4", "input5", "input6"], - dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], - "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) + self.run_model( + model, + [(images, features), (images2, test_features)], + tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={ + "input1": [0, 1, 2, 3], + "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], + "input6": [0, 1, 2, 3], + }, + ) def get_image_from_url(self, url, size=None): + from io import BytesIO + import requests from PIL import Image - from io import BytesIO + from torchvision import transforms data = requests.get(url) @@ -399,15 +457,23 @@ def test_faster_rcnn(self): model.eval() model(images) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"], - output_names=["outputs"], - dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_image,)], + input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) # Test exported model for an image with no detections on other images - self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"], - output_names=["outputs"], - dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_image,), (images,)], + input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) # Verify that paste_mask_in_image beahves the same in tracing. 
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image @@ -423,11 +489,11 @@ def test_paste_mask_in_image(self): boxes *= 50 o_im_s = (100, 100) from torchvision.models.detection.roi_heads import paste_masks_in_image + out = paste_masks_in_image(masks, boxes, o_im_s) - jit_trace = torch.jit.trace(paste_masks_in_image, - (masks, boxes, - [torch.tensor(o_im_s[0]), - torch.tensor(o_im_s[1])])) + jit_trace = torch.jit.trace( + paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) + ) out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) assert torch.all(out.eq(out_trace)) @@ -438,6 +504,7 @@ def test_paste_mask_in_image(self): boxes2 *= 100 o_im_s2 = (200, 200) from torchvision.models.detection.roi_heads import paste_masks_in_image + out2 = paste_masks_in_image(masks2, boxes2, o_im_s2) out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]) @@ -450,20 +517,36 @@ def test_mask_rcnn(self): model.eval() model(images) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images,), (test_images,), (dummy_image,)], - input_names=["images_tensors"], - output_names=["boxes", "labels", "scores", "masks"], - dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], - "scores": [0], "masks": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_image,)], + input_names=["images_tensors"], + output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={ + "images_tensors": [0, 1, 2], + "boxes": [0, 1], + "labels": [0], + "scores": [0], + "masks": [0, 1, 2], + }, + tolerate_small_mismatch=True, + ) # TODO: enable this test once dynamic model export is fixed # Test exported model for an image with no detections on other images - self.run_model(model, [(dummy_image,), (images,)], - input_names=["images_tensors"], - output_names=["boxes", "labels", "scores", "masks"], - dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], - "scores": [0], "masks": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_image,), (images,)], + input_names=["images_tensors"], + output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={ + "images_tensors": [0, 1, 2], + "boxes": [0, 1], + "labels": [0], + "scores": [0], + "masks": [0, 1, 2], + }, + tolerate_small_mismatch=True, + ) # Verify that heatmaps_to_keypoints behaves the same in tracing. 
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints @@ -477,6 +560,7 @@ def test_heatmaps_to_keypoints(self): maps = torch.rand(10, 1, 26, 26) rois = torch.rand(10, 4) from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out = heatmaps_to_keypoints(maps, rois) jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) out_trace = jit_trace(maps, rois) @@ -487,6 +571,7 @@ def test_heatmaps_to_keypoints(self): maps2 = torch.rand(20, 2, 21, 21) rois2 = torch.rand(20, 4) from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out2 = heatmaps_to_keypoints(maps2, rois2) out_trace2 = jit_trace(maps2, rois2) @@ -499,29 +584,38 @@ def test_keypoint_rcnn(self): model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model.eval() model(images) - self.run_model(model, [(images,), (test_images,), (dummy_images,)], - input_names=["images_tensors"], - output_names=["outputs1", "outputs2", "outputs3", "outputs4"], - dynamic_axes={"images_tensors": [0, 1, 2]}, - tolerate_small_mismatch=True) - - self.run_model(model, [(dummy_images,), (test_images,)], - input_names=["images_tensors"], - output_names=["outputs1", "outputs2", "outputs3", "outputs4"], - dynamic_axes={"images_tensors": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) + + self.run_model( + model, + [(dummy_images,), (test_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) def test_shufflenet_v2_dynamic_axes(self): model = models.shufflenet_v2_x0_5(pretrained=True) dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) - self.run_model(model, [(dummy_input,), (test_inputs,)], - input_names=["input_images"], - output_names=["output"], - dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_input,), (test_inputs,)], + input_names=["input_images"], + output_names=["output"], + dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}}, + tolerate_small_mismatch=True, + ) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index 5c2fc882902..275e0c051da 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,24 +1,25 @@ -from common_utils import needs_cuda, cpu_and_gpu, assert_equal import math from abc import ABC, abstractmethod -import pytest +from functools import lru_cache +from typing import Tuple import numpy as np +import pytest import torch -from functools import lru_cache from torch import Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair from torchvision import ops -from typing import Tuple + +from common_utils import assert_equal, cpu_and_gpu, needs_cuda class RoIOpTester(ABC): dtype = torch.float64 - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) def 
test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): x_dtype = self.dtype if x_dtype is None else x_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype @@ -28,33 +29,33 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) - rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) - [0, 0, 5, 4, 9], - [0, 5, 5, 9, 9], - [1, 0, 0, 9, 9]], - dtype=rois_dtype, device=device) + rois = torch.tensor( + [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) pool_h, pool_w = pool_size, pool_size y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) # the following should be true whether we're running an autocast test or not. assert y.dtype == x.dtype - gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1, - sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs) + gt_y = self.expected_fn( + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs + ) tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, device, contiguous): pool_size = 2 x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) - rois = torch.tensor([[0, 0, 0, 4, 4], # format is (xyxy) - [0, 0, 2, 3, 4], - [0, 2, 2, 4, 4]], - dtype=self.dtype, device=device) + rois = torch.tensor( + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy) + ) def func(z): return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) @@ -65,8 +66,8 @@ def func(z): gradcheck(script_func, (x,)) @needs_cuda - @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) - @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) @@ -105,8 +106,9 @@ def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_pool) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, - device=None, dtype=torch.float64): + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") @@ -119,7 +121,7 @@ def get_slice(k, block): for roi_idx, roi in enumerate(rois): batch_idx = int(roi[0]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) - roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] roi_h, roi_w = roi_x.shape[-2:] bin_h = roi_h / pool_h @@ -144,8 +146,9 @@ def get_script_fn(self, rois, pool_size): scriped = 
torch.jit.script(ops.ps_roi_pool) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, - device=None, dtype=torch.float64): + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") n_input_channels = x.size(1) @@ -159,7 +162,7 @@ def get_slice(k, block): for roi_idx, roi in enumerate(rois): batch_idx = int(roi[0]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) - roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] roi_height = max(i_end - i_begin, 1) roi_width = max(j_end - j_begin, 1) @@ -214,21 +217,32 @@ def bilinear_interpolate(data, y, x, snap_border=False): class TestRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): - return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio, aligned=aligned)(x, rois) + return ops.RoIAlign( + (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned + )(x, rois) def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_align) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, - device=None, dtype=torch.float64): + def expected_fn( + self, + in_data, + rois, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + aligned=False, + device=None, + dtype=torch.float64, + ): if device is None: device = torch.device("cpu") n_channels = in_data.size(1) out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device) - offset = 0.5 if aligned else 0. 
+ offset = 0.5 if aligned else 0.0 for r, roi in enumerate(rois): batch_idx = int(roi[0]) @@ -262,21 +276,23 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_align) - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): - super().test_forward(device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, - aligned=aligned) + super().test_forward( + device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned + ) @needs_cuda - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) - @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, aligned, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): - self.test_forward(torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, - rois_dtype=rois_dtype) + self.test_forward( + torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + ) def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) @@ -284,9 +300,9 @@ def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate return rois - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('scale, zero_point', ((1, 0), (2, 10), (0.1, 50))) - @pytest.mark.parametrize('qdtype', (torch.qint8, torch.quint8, torch.qint32)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50))) + @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32)) def test_qroialign(self, aligned, scale, zero_point, qdtype): """Make sure quantized version of RoIAlign is close to float version""" pool_size = 5 @@ -336,7 +352,7 @@ def test_qroialign(self, aligned, scale, zero_point, qdtype): # - any difference between qy and quantized_float_y is == scale diff_idx = torch.where(qy != quantized_float_y) num_diff = diff_idx[0].numel() - assert num_diff / qy.numel() < .05 + assert num_diff / qy.numel() < 0.05 abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize()) t_scale = torch.full_like(abs_diff, fill_value=scale) @@ -354,15 +370,15 @@ def test_qroi_align_multiple_images(self): class TestPSRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): - return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio)(x, rois) + return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.ps_roi_align) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, - 
sampling_ratio=-1, dtype=torch.float64): + def expected_fn( + self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") n_input_channels = in_data.size(1) @@ -405,15 +421,17 @@ def test_boxes_shape(self): class TestMultiScaleRoIAlign: def test_msroialign_repr(self): - fmap_names = ['0'] + fmap_names = ["0"] output_size = (7, 7) sampling_ratio = 2 # Pass mock feature map names t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) # Check integrity of object __repr__ attribute - expected_string = (f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " - f"sampling_ratio={sampling_ratio})") + expected_string = ( + f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " + f"sampling_ratio={sampling_ratio})" + ) assert repr(t) == expected_string @@ -458,9 +476,9 @@ def _create_tensors_with_iou(self, N, iou_thresh): scores = torch.rand(N) return boxes, scores - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_ref(self, iou): - err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}' + err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) keep_ref = self._reference_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) @@ -476,13 +494,13 @@ def test_nms_input_errors(self): with pytest.raises(RuntimeError): ops.nms(torch.rand(3, 4), torch.rand(4), 0.5) - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10))) def test_qnms(self, iou, scale, zero_point): # Note: we compare qnms vs nms instead of qnms vs reference implementation. 
# This is because with the int convertion, the trick used in _create_tensors_with_iou # doesn't really work (in fact, nms vs reference implem will also fail with ints) - err_msg = 'NMS and QNMS give different results for IoU={}' + err_msg = "NMS and QNMS give different results for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) scores *= 100 # otherwise most scores would be 0 or 1 after int convertion @@ -498,10 +516,10 @@ def test_qnms(self, iou, scale, zero_point): assert torch.allclose(qkeep, keep), err_msg.format(iou) @needs_cuda - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_cuda(self, iou, dtype=torch.float64): tol = 1e-3 if dtype is torch.half else 1e-5 - err_msg = 'NMS incompatible between CPU and CUDA for IoU={}' + err_msg = "NMS incompatible between CPU and CUDA for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) r_cpu = ops.nms(boxes, scores, iou) @@ -515,7 +533,7 @@ def test_nms_cuda(self, iou, dtype=torch.float64): assert is_eq, err_msg.format(iou) @needs_cuda - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): @@ -523,9 +541,13 @@ def test_autocast(self, iou, dtype): @needs_cuda def test_nms_cuda_float16(self): - boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() + boxes = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + ).cuda() scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() iou_thres = 0.2 @@ -537,7 +559,7 @@ def test_batched_nms_implementations(self): """Make sure that both implementations of batched_nms yield identical results""" num_boxes = 1000 - iou_threshold = .9 + iou_threshold = 0.9 boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 @@ -601,8 +623,11 @@ def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilati if mask is not None: mask_value = mask[b, mask_idx, i, j] - out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] * - bilinear_interpolate(x[b, c_in, :, :], pi, pj)) + out[b, c_out, i, j] += ( + mask_value + * weight[c_out, c, di, dj] + * bilinear_interpolate(x[b, c_in, :, :], pi, pj) + ) out += bias.view(1, n_out_channels, 1, 1) return out @@ -628,14 +653,29 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True) - offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w, - device=device, dtype=dtype, requires_grad=True) + offset = torch.randn( + batch_sz, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w, + device=device, + dtype=dtype, + requires_grad=True, + ) - mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, - device=device, dtype=dtype, requires_grad=True) + mask = torch.randn( + batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True + ) - weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, - device=device, dtype=dtype, requires_grad=True) + weight = torch.randn( + n_out_channels, 
+ n_in_channels // n_weight_grps, + weight_h, + weight_w, + device=device, + dtype=dtype, + requires_grad=True, + ) bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True) @@ -647,9 +687,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): return x, weight, offset, mask, bias, stride, pad, dilation - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) - @pytest.mark.parametrize('batch_sz', (0, 33)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("batch_sz", (0, 33)) def test_forward(self, device, contiguous, batch_sz, dtype=None): dtype = dtype or self.dtype x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) @@ -659,8 +699,9 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): groups = 2 tol = 2e-3 if dtype is torch.half else 1e-5 - layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups).to(device=x.device, dtype=dtype) + layer = ops.DeformConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups + ).to(device=x.device, dtype=dtype) res = layer(x, offset, mask) weight = layer.weight.data @@ -668,7 +709,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) ) # no modulation test @@ -676,7 +717,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) ) def test_wrong_sizes(self): @@ -684,57 +725,72 @@ def test_wrong_sizes(self): out_channels = 2 kernel_size = (3, 2) groups = 2 - x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args('cpu', contiguous=True, - batch_sz=10, dtype=self.dtype) - layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups) + x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args( + "cpu", contiguous=True, batch_sz=10, dtype=self.dtype + ) + layer = ops.DeformConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups + ) with pytest.raises(RuntimeError, match="the shape of the offset"): wrong_offset = torch.rand_like(offset[:, :2]) layer(x, wrong_offset) - with pytest.raises(RuntimeError, match=r'mask.shape\[1\] is not valid'): + with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"): wrong_mask = torch.rand_like(mask[:, :2]) layer(x, offset, wrong_mask) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) - @pytest.mark.parametrize('batch_sz', (0, 33)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, 
False)) + @pytest.mark.parametrize("batch_sz", (0, 33)) def test_backward(self, device, contiguous, batch_sz): - x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, - batch_sz, self.dtype) + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( + device, contiguous, batch_sz, self.dtype + ) def func(x_, offset_, mask_, weight_, bias_): - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, - padding=padding, dilation=dilation, mask=mask_) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_ + ) gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) def func_no_mask(x_, offset_, weight_, bias_): - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, - padding=padding, dilation=dilation, mask=None) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None + ) gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) @torch.jit.script def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, - padding=pad_, dilation=dilation_, mask=mask_) - - gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), - (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_ + ) + + gradcheck( + lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), + (x, offset, mask, weight, bias), + nondet_tol=1e-5, + fast_mode=True, + ) @torch.jit.script def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, - padding=pad_, dilation=dilation_, mask=None) - - gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), - (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None + ) + + gradcheck( + lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), + (x, offset, weight, bias), + nondet_tol=1e-5, + fast_mode=True, + ) @needs_cuda - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("contiguous", (True, False)) def test_compare_cpu_cuda_grads(self, contiguous): # Test from https://github.com/pytorch/vision/issues/2598 # Run on CUDA only @@ -768,8 +824,8 @@ def test_compare_cpu_cuda_grads(self, contiguous): torch.testing.assert_close(true_cpu_grads, res_grads) @needs_cuda - @pytest.mark.parametrize('batch_sz', (0, 33)) - @pytest.mark.parametrize('dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, batch_sz, dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) @@ -792,11 +848,13 @@ def test_frozenbatchnorm2d_repr(self): def 
test_frozenbatchnorm2d_eps(self): sample_size = (4, 32, 28, 28) x = torch.rand(sample_size) - state_dict = dict(weight=torch.rand(sample_size[1]), - bias=torch.rand(sample_size[1]), - running_mean=torch.rand(sample_size[1]), - running_var=torch.rand(sample_size[1]), - num_batches_tracked=torch.tensor(100)) + state_dict = dict( + weight=torch.rand(sample_size[1]), + bias=torch.rand(sample_size[1]), + running_mean=torch.rand(sample_size[1]), + running_var=torch.rand(sample_size[1]), + num_batches_tracked=torch.tensor(100), + ) # Check that default eps is equal to the one of BN fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) @@ -824,17 +882,19 @@ class TestBoxConversion: def _get_box_sequences(): # Define here the argument type of `boxes` supported by region pooling operations box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) - box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float), - torch.tensor([[0, 0, 100, 100]], dtype=torch.float)] + box_list = [ + torch.tensor([[0, 0, 100, 100]], dtype=torch.float), + torch.tensor([[0, 0, 100, 100]], dtype=torch.float), + ] box_tuple = tuple(box_list) return box_tensor, box_list, box_tuple - @pytest.mark.parametrize('box_sequence', _get_box_sequences()) + @pytest.mark.parametrize("box_sequence", _get_box_sequences()) def test_check_roi_boxes_shape(self, box_sequence): # Ensure common sequences of tensors are supported ops._utils.check_roi_boxes_shape(box_sequence) - @pytest.mark.parametrize('box_sequence', _get_box_sequences()) + @pytest.mark.parametrize("box_sequence", _get_box_sequences()) def test_convert_boxes_to_roi_format(self, box_sequence): # Ensure common sequences of tensors yield the same result ref_tensor = None @@ -846,11 +906,11 @@ def test_convert_boxes_to_roi_format(self, box_sequence): class TestBox: def test_bbox_same(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) - exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) assert exp_xyxy.size() == torch.Size([4, 4]) assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy) @@ -860,10 +920,10 @@ def test_bbox_same(self): def test_bbox_xyxy_xywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) + exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) assert exp_xywh.size() == torch.Size([4, 4]) box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") @@ -876,10 +936,12 @@ def test_bbox_xyxy_xywh(self): def test_bbox_xyxy_cxcywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. 
- box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], - [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) + exp_cxcywh = torch.tensor( + [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float + ) assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") @@ -890,12 +952,14 @@ def test_bbox_xyxy_cxcywh(self): assert_equal(box_xyxy, box_tensor) def test_bbox_xywh_cxcywh(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float + ) # This is wrong - exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], - [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + exp_cxcywh = torch.tensor( + [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float + ) assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh") @@ -905,28 +969,30 @@ def test_bbox_xywh_cxcywh(self): box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh") assert_equal(box_xywh, box_tensor) - @pytest.mark.parametrize('inv_infmt', ["xwyh", "cxwyh"]) - @pytest.mark.parametrize('inv_outfmt', ["xwcx", "xhwcy"]) + @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"]) + @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"]) def test_bbox_invalid(self, inv_infmt, inv_outfmt): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float + ) with pytest.raises(ValueError): ops.box_convert(box_tensor, inv_infmt, inv_outfmt) def test_bbox_convert_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) scripted_fn = torch.jit.script(ops.box_convert) TOLERANCE = 1e-3 box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") - scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh') + scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh") torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") - scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh') + scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh") torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) @@ -944,16 +1010,22 @@ def area_check(box, expected, tolerance=1e-4): # Check for float32 and float64 boxes for dtype in [torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = 
torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) area_check(box_tensor, expected, tolerance=0.05) # Check for float16 box - box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5], - [285.25, 188.75, 1192.0, 851.0], - [279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16) + box_tensor = torch.tensor( + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], + dtype=torch.float16, + ) expected = torch.tensor([605113.875, 600495.1875, 592247.25]) area_check(box_tensor, expected) @@ -972,9 +1044,14 @@ def iou_check(box, expected, tolerance=1e-4): # Check for float boxes for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) @@ -993,12 +1070,17 @@ def gen_iou_check(box, expected, tolerance=1e-4): # Check for float boxes for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_transforms.py b/test/test_transforms.py index 74757bcb4e6..f25f1cfe7cb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,14 +1,19 @@ +import math import os +import random + +import numpy as np +import pytest +from PIL import Image + import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 -import math -import random -import numpy as np -import pytest -from PIL import Image + +from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes + try: import accimage except ImportError: @@ -19,15 +24,14 @@ except ImportError: stats = None -from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal - GRACE_HOPPER = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" +) class TestConvertImageDtype: - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(float_dtypes())) def test_float_to_float(self, input_dtype, output_dtype): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) @@ -44,15 +48,15 @@ def test_float_to_float(self, input_dtype, output_dtype): 
assert abs(actual_min - desired_min) < 1e-7 assert abs(actual_max - desired_max) < 1e-7 - @pytest.mark.parametrize('input_dtype', float_dtypes()) - @pytest.mark.parametrize('output_dtype', int_dtypes()) + @pytest.mark.parametrize("input_dtype", float_dtypes()) + @pytest.mark.parametrize("output_dtype", int_dtypes()) def test_float_to_int(self, input_dtype, output_dtype): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( - input_dtype == torch.float64 and output_dtype == torch.int64 + input_dtype == torch.float64 and output_dtype == torch.int64 ): with pytest.raises(RuntimeError): transform(input_image) @@ -68,8 +72,8 @@ def test_float_to_int(self, input_dtype, output_dtype): assert actual_min == desired_min assert actual_max == desired_max - @pytest.mark.parametrize('input_dtype', int_dtypes()) - @pytest.mark.parametrize('output_dtype', float_dtypes()) + @pytest.mark.parametrize("input_dtype", int_dtypes()) + @pytest.mark.parametrize("output_dtype", float_dtypes()) def test_int_to_float(self, input_dtype, output_dtype): input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) @@ -88,7 +92,7 @@ def test_int_to_float(self, input_dtype, output_dtype): assert abs(actual_max - desired_max) < 1e-7 assert actual_max <= desired_max - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes())) def test_dtype_int_to_int(self, input_dtype, output_dtype): input_max = torch.iinfo(input_dtype).max input_image = torch.tensor((0, input_max), dtype=input_dtype) @@ -120,7 +124,7 @@ def test_dtype_int_to_int(self, input_dtype, output_dtype): assert actual_min == desired_min assert actual_max == (desired_max + error_term) - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes())) def test_int_to_int_consistency(self, input_dtype, output_dtype): input_max = torch.iinfo(input_dtype).max input_image = torch.tensor((0, input_max), dtype=input_dtype) @@ -142,11 +146,10 @@ def test_int_to_int_consistency(self, input_dtype, output_dtype): @pytest.mark.skipif(accimage is None, reason="accimage not available") class TestAccImage: - def test_accimage_to_tensor(self): trans = transforms.ToTensor() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) torch.testing.assert_close(output, expected_output) @@ -154,22 +157,24 @@ def test_accimage_to_tensor(self): def test_accimage_pil_to_tensor(self): trans = transforms.PILToTensor() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() torch.testing.assert_close(output, expected_output) def test_accimage_resize(self): - trans = transforms.Compose([ - transforms.Resize(256, interpolation=Image.LINEAR), - transforms.ToTensor(), - ]) + trans = transforms.Compose( + [ + transforms.Resize(256, interpolation=Image.LINEAR), + transforms.ToTensor(), + ] + ) # Checking if Compose, Resize and ToTensor can be 
printed as string trans.__repr__() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() @@ -179,15 +184,17 @@ def test_accimage_resize(self): torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2) def test_accimage_crop(self): - trans = transforms.Compose([ - transforms.CenterCrop(256), - transforms.ToTensor(), - ]) + trans = transforms.Compose( + [ + transforms.CenterCrop(256), + transforms.ToTensor(), + ] + ) # Checking if Compose, CenterCrop and ToTensor can be printed as string trans.__repr__() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() @@ -195,8 +202,7 @@ def test_accimage_crop(self): class TestToTensor: - - @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize("channels", [1, 3, 4]) def test_to_tensor(self, channels): height, width = 4, 4 trans = transforms.ToTensor() @@ -218,7 +224,7 @@ def test_to_tensor(self, channels): # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() - img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + img = transforms.ToPILImage()(input_data.mul(255)).convert("1") output = trans(img) torch.testing.assert_close(input_data, output, check_dtype=False) @@ -235,7 +241,7 @@ def test_to_tensor_errors(self): with pytest.raises(ValueError): trans(np.random.rand(1, 1, height, width)) - @pytest.mark.parametrize('dtype', [torch.float16, torch.float, torch.double]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.float, torch.double]) def test_to_tensor_with_other_default_dtypes(self, dtype): current_def_dtype = torch.get_default_dtype() @@ -249,7 +255,7 @@ def test_to_tensor_with_other_default_dtypes(self, dtype): torch.set_default_dtype(current_def_dtype) - @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize("channels", [1, 3, 4]) def test_pil_to_tensor(self, channels): height, width = 4, 4 trans = transforms.PILToTensor() @@ -273,7 +279,7 @@ def test_pil_to_tensor(self, channels): # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() - img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + img = transforms.ToPILImage()(input_data.mul(255)).convert("1") output = trans(img).view(torch.uint8).bool().to(torch.uint8) torch.testing.assert_close(input_data, output) @@ -305,34 +311,47 @@ def test_randomresized_params(): randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range) i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range) aspect_ratio_obtained = w / h - assert((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and - aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or - aspect_ratio_obtained == 1.0) + assert ( + min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained + and aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon + ) or aspect_ratio_obtained == 1.0 assert isinstance(i, int) assert isinstance(j, int) assert isinstance(h, int) assert isinstance(w, int) -@pytest.mark.parametrize('height, width', [ - # height, width - # square image - (28, 28), - (27, 27), - # rectangular image: h < w - (28, 34), - (29, 35), - # rectangular 
image: h > w - (34, 28), - (35, 29), -]) -@pytest.mark.parametrize('osize', [ - # single integer - 22, 27, 28, 36, - # single integer in tuple/list - [22, ], (27, ), -]) -@pytest.mark.parametrize('max_size', (None, 37, 1000)) +@pytest.mark.parametrize( + "height, width", + [ + # height, width + # square image + (28, 28), + (27, 27), + # rectangular image: h < w + (28, 34), + (29, 35), + # rectangular image: h > w + (34, 28), + (35, 29), + ], +) +@pytest.mark.parametrize( + "osize", + [ + # single integer + 22, + 27, + 28, + 36, + # single integer in tuple/list + [ + 22, + ], + (27,), + ], +) +@pytest.mark.parametrize("max_size", (None, 37, 1000)) def test_resize(height, width, osize, max_size): img = Image.new("RGB", size=(width, height), color=127) @@ -360,24 +379,36 @@ def test_resize(height, width, osize, max_size): assert result.size == (exp_w, exp_h), msg -@pytest.mark.parametrize('height, width', [ - # height, width - # square image - (28, 28), - (27, 27), - # rectangular image: h < w - (28, 34), - (29, 35), - # rectangular image: h > w - (34, 28), - (35, 29), -]) -@pytest.mark.parametrize('osize', [ - # two integers sequence output - [22, 22], [22, 28], [22, 36], - [27, 22], [36, 22], [28, 28], - [28, 37], [37, 27], [37, 37] -]) +@pytest.mark.parametrize( + "height, width", + [ + # height, width + # square image + (28, 28), + (27, 27), + # rectangular image: h < w + (28, 34), + (29, 35), + # rectangular image: h > w + (34, 28), + (35, 29), + ], +) +@pytest.mark.parametrize( + "osize", + [ + # two integers sequence output + [22, 22], + [22, 28], + [22, 36], + [27, 22], + [36, 22], + [28, 28], + [28, 37], + [37, 27], + [37, 37], + ], +) def test_resize_sequence_output(height, width, osize): img = Image.new("RGB", size=(width, height), color=127) oheight, owidth = osize @@ -398,18 +429,19 @@ def test_resize_antialias_error(): class TestPad: - def test_pad(self): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 img = torch.ones(3, height, width) padding = random.randint(1, 20) fill = random.randint(1, 50) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.Pad(padding, fill=fill), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Pad(padding, fill=fill), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == height + 2 * padding assert result.size(2) == width + 2 * padding # check that all elements in the padded region correspond @@ -418,14 +450,9 @@ def test_pad(self): eps = 1e-5 h_padded = result[:, :padding, :] w_padded = result[:, :, :padding] - torch.testing.assert_close( - h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps - ) - torch.testing.assert_close( - w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps - ) - pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), - transforms.ToPILImage()(img)) + torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps) + torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps) + pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img)) def test_pad_with_tuple_of_pad_values(self): height = random.randint(10, 32) * 2 @@ -452,7 +479,7 @@ def test_pad_with_non_constant_padding_modes(self): img = F.pad(img, 1, (200, 200, 200)) # pad 3 to all sidess - edge_padded_img = F.pad(img, 3, padding_mode='edge') + edge_padded_img = F.pad(img, 3, padding_mode="edge") # 
First 6 elements of leftmost edge in the middle of the image, values are in order: # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -460,7 +487,7 @@ def test_pad_with_non_constant_padding_modes(self): assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35) # Pad 3 to left/right, 2 to top/bottom - reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect') + reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect") # First 6 elements of leftmost edge in the middle of the image, values are in order: # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -468,7 +495,7 @@ def test_pad_with_non_constant_padding_modes(self): assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35) # Pad 3 to left, 2 to top, 2 to right, 1 to bottom - symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric') + symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric") # First 6 elements of leftmost edge in the middle of the image, values are in order: # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -478,7 +505,7 @@ def test_pad_with_non_constant_padding_modes(self): # Check negative padding explicitly for symmetric case, since it is not # implemented for tensor case to compare to # Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom - symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric') + symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode="symmetric") symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3] symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:] assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8)) @@ -505,14 +532,18 @@ def test_pad_with_mode_F_images(self): @pytest.mark.skipif(stats is None, reason="scipy.stats not available") -@pytest.mark.parametrize('fn, trans, config', [ - (F.invert, transforms.RandomInvert, {}), - (F.posterize, transforms.RandomPosterize, {"bits": 4}), - (F.solarize, transforms.RandomSolarize, {"threshold": 192}), - (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), - (F.autocontrast, transforms.RandomAutocontrast, {}), - (F.equalize, transforms.RandomEqualize, {})]) -@pytest.mark.parametrize('p', (.5, .7)) +@pytest.mark.parametrize( + "fn, trans, config", + [ + (F.invert, transforms.RandomInvert, {}), + (F.posterize, transforms.RandomPosterize, {"bits": 4}), + (F.solarize, transforms.RandomSolarize, {"threshold": 192}), + (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), + (F.autocontrast, transforms.RandomAutocontrast, {}), + (F.equalize, transforms.RandomEqualize, {}), + ], +) +@pytest.mark.parametrize("p", (0.5, 0.7)) def test_randomness(fn, trans, config, p): random_state = random.getstate() random.seed(42) @@ -535,43 +566,42 @@ def test_randomness(fn, trans, config, p): class TestToPil: - def _get_1_channel_tensor_various_types(): img_data_float = torch.Tensor(1, 4, 4).uniform_() expected_output = img_data_float.mul(255).int().float().div(255).numpy() - yield img_data_float, expected_output, 'L' + yield 
img_data_float, expected_output, "L" img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255) expected_output = img_data_byte.float().div(255.0).numpy() - yield img_data_byte, expected_output, 'L' + yield img_data_byte, expected_output, "L" img_data_short = torch.ShortTensor(1, 4, 4).random_() expected_output = img_data_short.numpy() - yield img_data_short, expected_output, 'I;16' + yield img_data_short, expected_output, "I;16" img_data_int = torch.IntTensor(1, 4, 4).random_() expected_output = img_data_int.numpy() - yield img_data_int, expected_output, 'I' + yield img_data_int, expected_output, "I" def _get_2d_tensor_various_types(): img_data_float = torch.Tensor(4, 4).uniform_() expected_output = img_data_float.mul(255).int().float().div(255).numpy() - yield img_data_float, expected_output, 'L' + yield img_data_float, expected_output, "L" img_data_byte = torch.ByteTensor(4, 4).random_(0, 255) expected_output = img_data_byte.float().div(255.0).numpy() - yield img_data_byte, expected_output, 'L' + yield img_data_byte, expected_output, "L" img_data_short = torch.ShortTensor(4, 4).random_() expected_output = img_data_short.numpy() - yield img_data_short, expected_output, 'I;16' + yield img_data_short, expected_output, "I;16" img_data_int = torch.IntTensor(4, 4).random_() expected_output = img_data_int.numpy() - yield img_data_int, expected_output, 'I' + yield img_data_int, expected_output, "I" - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_1_channel_tensor_various_types()) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_1_channel_tensor_various_types()) def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() to_tensor = transforms.ToTensor() @@ -583,19 +613,22 @@ def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_outpu def test_1_channel_float_tensor_to_pil_image(self): img_data = torch.Tensor(1, 4, 4).uniform_() # 'F' mode for torch.FloatTensor - img_F_mode = transforms.ToPILImage(mode='F')(img_data) - assert img_F_mode.mode == 'F' + img_F_mode = transforms.ToPILImage(mode="F")(img_data) + assert img_F_mode.mode == "F" torch.testing.assert_close( - np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode='F')), np.array(img_F_mode) + np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode) ) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_mode', [ - (torch.Tensor(4, 4, 1).uniform_().numpy(), 'F'), - (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), 'L'), - (torch.ShortTensor(4, 4, 1).random_().numpy(), 'I;16'), - (torch.IntTensor(4, 4, 1).random_().numpy(), 'I'), - ]) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize( + "img_data, expected_mode", + [ + (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"), + (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"), + (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"), + (torch.IntTensor(4, 4, 1).random_().numpy(), "I"), + ], + ) def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) @@ -604,13 +637,13 @@ def test_1_channel_ndarray_to_pil_image(self, 
with_mode, img_data, expected_mode # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) - @pytest.mark.parametrize('expected_mode', [None, 'LA']) + @pytest.mark.parametrize("expected_mode", [None, "LA"]) def test_2_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'LA' # default should assume LA + assert img.mode == "LA" # default should assume LA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -624,19 +657,19 @@ def test_2_channel_ndarray_to_pil_image_error(self): # should raise if we try a mode for 4 or 1 or 3 channel images with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'LA']) + @pytest.mark.parametrize("expected_mode", [None, "LA"]) def test_2_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(2, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'LA' # default should assume LA + assert img.mode == "LA" # default should assume LA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -650,14 +683,14 @@ def test_2_channel_tensor_to_pil_image_error(self): # should raise if we try a mode for 4 or 1 or 3 channel images with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_2d_tensor_various_types()) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_2d_tensor_various_types()) def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() to_tensor = transforms.ToTensor() @@ -666,27 +699,30 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe assert img.mode == expected_mode torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0]) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_mode', [ - (torch.Tensor(4, 4).uniform_().numpy(), 'F'), - 
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), 'L'), - (torch.ShortTensor(4, 4).random_().numpy(), 'I;16'), - (torch.IntTensor(4, 4).random_().numpy(), 'I'), - ]) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize( + "img_data, expected_mode", + [ + (torch.Tensor(4, 4).uniform_().numpy(), "F"), + (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"), + (torch.ShortTensor(4, 4).random_().numpy(), "I;16"), + (torch.IntTensor(4, 4).random_().numpy(), "I"), + ], + ) def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode np.testing.assert_allclose(img_data, img) - @pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) + @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) def test_3_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(3, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGB' # default should assume RGB + assert img.mode == "RGB" # default should assume RGB else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -699,22 +735,22 @@ def test_3_channel_tensor_to_pil_image_error(self): error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" # should raise if we try a mode for 4 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. 
Got \d+ dimensions."): transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_()) - @pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) + @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) def test_3_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGB' # default should assume RGB + assert img.mode == "RGB" # default should assume RGB else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -731,20 +767,20 @@ def test_3_channel_ndarray_to_pil_image_error(self): error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" # should raise if we try a mode for 4 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) + @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"]) def test_4_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(4, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGBA' # default should assume RGBA + assert img.mode == "RGBA" # default should assume RGBA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -759,19 +795,19 @@ def test_4_channel_tensor_to_pil_image_error(self): error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" # should raise if we try a mode for 3 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) + @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"]) def test_4_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGBA' # default should assume RGBA + assert img.mode == "RGBA" # default should assume RGBA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -785,15 +821,15 @@ def test_4_channel_ndarray_to_pil_image_error(self): error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" # should raise if we try a mode for 3 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) with pytest.raises(ValueError, match=error_message_4d): - 
transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() - reg_msg = r'Input type \w+ is not supported' + reg_msg = r"Input type \w+ is not supported" with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.int64)) with pytest.raises(TypeError, match=reg_msg): @@ -803,15 +839,15 @@ def test_ndarray_bad_types_to_pil_image(self): with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.float64)) - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) - with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'): + with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."): transforms.ToPILImage()(np.ones([4, 4, 6])) def test_tensor_bad_types_to_pil_image(self): - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(torch.ones(1, 3, 4, 4)) - with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'): + with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."): transforms.ToPILImage()(torch.ones(6, 4, 4)) @@ -819,7 +855,7 @@ def test_adjust_brightness(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_brightness(x_pil, 1) @@ -845,7 +881,7 @@ def test_adjust_contrast(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_contrast(x_pil, 1) @@ -867,12 +903,12 @@ def test_adjust_contrast(): torch.testing.assert_close(y_np, y_ans) -@pytest.mark.skipif(Image.__version__ >= '7', reason="Temporarily disabled") +@pytest.mark.skipif(Image.__version__ >= "7", reason="Temporarily disabled") def test_adjust_saturation(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_saturation(x_pil, 1) @@ -898,7 +934,7 @@ def test_adjust_hue(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") with pytest.raises(ValueError): F.adjust_hue(x_pil, -0.7) @@ -929,11 +965,58 @@ def test_adjust_hue(): def test_adjust_sharpness(): x_shape = [4, 4, 3] - x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, - 0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, - 111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + x_data = [ + 75, + 121, + 114, + 105, 
+ 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 0, + 0, + 65, + 108, + 101, + 120, + 97, + 110, + 100, + 101, + 114, + 32, + 86, + 114, + 121, + 110, + 105, + 111, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_sharpness(x_pil, 1) @@ -943,18 +1026,112 @@ def test_adjust_sharpness(): # test 1 y_pil = F.adjust_sharpness(x_pil, 0.5) y_np = np.array(y_pil) - y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, - 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, - 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = [ + 75, + 121, + 114, + 105, + 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 30, + 30, + 74, + 103, + 96, + 114, + 97, + 110, + 100, + 101, + 114, + 32, + 81, + 103, + 108, + 102, + 101, + 107, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) torch.testing.assert_close(y_np, y_ans) # test 2 y_pil = F.adjust_sharpness(x_pil, 2) y_np = np.array(y_pil) - y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, - 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, - 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = [ + 75, + 121, + 114, + 105, + 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 0, + 0, + 46, + 118, + 111, + 132, + 97, + 110, + 100, + 101, + 114, + 32, + 95, + 135, + 146, + 126, + 112, + 119, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) torch.testing.assert_close(y_np, y_ans) @@ -962,7 +1139,7 @@ def test_adjust_sharpness(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") x_th = torch.tensor(x_np.transpose(2, 0, 1)) y_pil = F.adjust_sharpness(x_pil, 2) y_np = np.array(y_pil).transpose(2, 0, 1) @@ -974,7 +1151,7 @@ def test_adjust_gamma(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_gamma(x_pil, 1) @@ -1000,15 +1177,15 @@ def test_adjusts_L_mode(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_rgb = Image.fromarray(x_np, mode='RGB') + x_rgb = Image.fromarray(x_np, mode="RGB") - x_l = x_rgb.convert('L') - assert F.adjust_brightness(x_l, 2).mode == 'L' - assert F.adjust_saturation(x_l, 2).mode == 'L' - assert F.adjust_contrast(x_l, 2).mode == 'L' - assert F.adjust_hue(x_l, 0.4).mode == 'L' - assert F.adjust_sharpness(x_l, 2).mode == 'L' - assert F.adjust_gamma(x_l, 0.5).mode == 'L' + x_l = x_rgb.convert("L") + assert F.adjust_brightness(x_l, 2).mode == "L" + assert F.adjust_saturation(x_l, 2).mode == "L" + assert F.adjust_contrast(x_l, 2).mode == "L" + assert F.adjust_hue(x_l, 0.4).mode == "L" + assert F.adjust_sharpness(x_l, 2).mode == 
"L" + assert F.adjust_gamma(x_l, 0.5).mode == "L" def test_rotate(): @@ -1047,7 +1224,7 @@ def test_rotate(): assert_equal(np.array(result_a), np.array(result_b)) -@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) +@pytest.mark.parametrize("mode", ["L", "RGB", "F"]) def test_rotate_fill(mode): img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB") @@ -1130,8 +1307,8 @@ def test_to_grayscale(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) # Test Set: Grayscale an image with desired number of output channels @@ -1139,16 +1316,16 @@ def test_to_grayscale(): trans1 = transforms.Grayscale(num_output_channels=1) gray_pil_1 = trans1(x_pil) gray_np_1 = np.array(gray_pil_1) - assert gray_pil_1.mode == 'L', 'mode should be L' - assert gray_np_1.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_1.mode == "L", "mode should be L" + assert gray_np_1.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_1) # Case 2: RGB -> 3 channel grayscale trans2 = transforms.Grayscale(num_output_channels=3) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np, gray_np_2[:, :, 0]) @@ -1157,16 +1334,16 @@ def test_to_grayscale(): trans3 = transforms.Grayscale(num_output_channels=1) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Case 4: 1 channel grayscale -> 3 channel grayscale trans4 = transforms.Grayscale(num_output_channels=3) gray_pil_4 = trans4(x_pil_2) gray_np_4 = np.array(gray_pil_4) - assert gray_pil_4.mode == 'RGB', 'mode should be RGB' - assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_4.mode == "RGB", "mode should be RGB" + assert gray_np_4.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1]) assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) assert_equal(gray_np, gray_np_4[:, :, 0]) @@ -1184,8 +1361,8 @@ def test_random_grayscale(): random.seed(42) x_shape = [2, 2, 3] x_np = np.random.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) num_samples = 250 @@ -1193,9 +1370,11 @@ def test_random_grayscale(): for _ in range(num_samples): gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil) gray_np_2 = np.array(gray_pil_2) - if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \ - np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \ - np.array_equal(gray_np, gray_np_2[:, :, 0]): + if ( + np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) + and 
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) + and np.array_equal(gray_np, gray_np_2[:, :, 0]) + ): num_gray = num_gray + 1 p_value = stats.binom_test(num_gray, num_samples, p=0.5) @@ -1207,8 +1386,8 @@ def test_random_grayscale(): random.seed(42) x_shape = [2, 2, 3] x_np = np.random.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) num_samples = 250 @@ -1227,16 +1406,16 @@ def test_random_grayscale(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) # Case 3a: RGB -> 3 channel grayscale (grayscaled) trans2 = transforms.RandomGrayscale(p=1.0) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np, gray_np_2[:, :, 0]) @@ -1245,31 +1424,31 @@ def test_random_grayscale(): trans2 = transforms.RandomGrayscale(p=0.0) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(x_np, gray_np_2) # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) trans3 = transforms.RandomGrayscale(p=1.0) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) trans3 = transforms.RandomGrayscale(p=0.0) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Checking if RandomGrayscale can be printed as string trans3.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_apply(): random_state = random.getstate() random.seed(42) @@ -1278,7 +1457,8 @@ def test_random_apply(): transforms.RandomRotation((-45, 45)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), - ], p=0.75 + ], + p=0.75, ) img = transforms.ToPILImage()(torch.rand(3, 10, 10)) num_samples = 250 @@ -1296,16 +1476,12 @@ def test_random_apply(): random_apply_transform.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def 
test_random_choice(): random_state = random.getstate() random.seed(42) random_choice_transform = transforms.RandomChoice( - [ - transforms.Resize(15), - transforms.Resize(20), - transforms.CenterCrop(10) - ] + [transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)] ) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 @@ -1333,16 +1509,11 @@ def test_random_choice(): random_choice_transform.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_order(): random_state = random.getstate() random.seed(42) - random_order_transform = transforms.RandomOrder( - [ - transforms.Resize(20), - transforms.CenterCrop(10) - ] - ) + random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)]) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 num_normal_order = 0 @@ -1368,10 +1539,10 @@ def test_linear_transformation(): sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0) u, s, _ = np.linalg.svd(sigma.numpy()) zca_epsilon = 1e-10 # avoid division by 0 - d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon))) + d = torch.Tensor(np.diag(1.0 / np.sqrt(s + zca_epsilon))) u = torch.Tensor(u) principal_components = torch.mm(torch.mm(u, d), u.t()) - mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0)) + mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0) # initialize whitening matrix whitening = transforms.LinearTransformation(principal_components, mean_vector) # estimate covariance and mean using weak law of large number @@ -1384,16 +1555,18 @@ def test_linear_transformation(): cov += np.dot(xwhite, xwhite.T) / num_features mean += np.sum(xwhite) / num_features # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov - torch.testing.assert_close(cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, - msg="cov not close to 1") - torch.testing.assert_close(mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, - msg="mean not close to 0") + torch.testing.assert_close( + cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, msg="cov not close to 1" + ) + torch.testing.assert_close( + mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, msg="mean not close to 0" + ) # Checking if LinearTransformation can be printed as string whitening.__repr__() -@pytest.mark.parametrize('dtype', int_dtypes()) +@pytest.mark.parametrize("dtype", int_dtypes()) def test_max_value(dtype): assert F_t._max_value(dtype) == torch.iinfo(dtype).max @@ -1403,8 +1576,8 @@ def test_max_value(dtype): # self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max) -@pytest.mark.parametrize('should_vflip', [True, False]) -@pytest.mark.parametrize('single_dim', [True, False]) +@pytest.mark.parametrize("should_vflip", [True, False]) +@pytest.mark.parametrize("single_dim", [True, False]) def test_ten_crop(should_vflip, single_dim): to_pil_image = transforms.ToPILImage() h = random.randint(5, 25) @@ -1414,12 +1587,10 @@ def test_ten_crop(should_vflip, single_dim): if single_dim: crop_h = min(crop_h, crop_w) crop_w = crop_h - transform = transforms.TenCrop(crop_h, - vertical_flip=should_vflip) + transform = transforms.TenCrop(crop_h, vertical_flip=should_vflip) five_crop = transforms.FiveCrop(crop_h) else: - transform = transforms.TenCrop((crop_h, crop_w), - vertical_flip=should_vflip) + transform = transforms.TenCrop((crop_h, crop_w), 
vertical_flip=should_vflip) five_crop = transforms.FiveCrop((crop_h, crop_w)) img = to_pil_image(torch.FloatTensor(3, h, w).uniform_()) @@ -1441,7 +1612,7 @@ def test_ten_crop(should_vflip, single_dim): assert results == expected_output -@pytest.mark.parametrize('single_dim', [True, False]) +@pytest.mark.parametrize("single_dim", [True, False]) def test_five_crop(single_dim): to_pil_image = transforms.ToPILImage() h = random.randint(5, 25) @@ -1465,16 +1636,16 @@ def test_five_crop(single_dim): to_pil_image = transforms.ToPILImage() tl = to_pil_image(img[:, 0:crop_h, 0:crop_w]) - tr = to_pil_image(img[:, 0:crop_h, w - crop_w:]) - bl = to_pil_image(img[:, h - crop_h:, 0:crop_w]) - br = to_pil_image(img[:, h - crop_h:, w - crop_w:]) + tr = to_pil_image(img[:, 0:crop_h, w - crop_w :]) + bl = to_pil_image(img[:, h - crop_h :, 0:crop_w]) + br = to_pil_image(img[:, h - crop_h :, w - crop_w :]) center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img)) expected_output = (tl, tr, bl, br, center) assert results == expected_output -@pytest.mark.parametrize('policy', transforms.AutoAugmentPolicy) -@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("policy", transforms.AutoAugmentPolicy) +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) def test_autoaugment(policy, fill): random.seed(42) img = Image.open(GRACE_HOPPER) @@ -1490,37 +1661,41 @@ def test_random_crop(): oheight = random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 img = torch.ones(3, height, width) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.RandomCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == oheight assert result.size(2) == owidth padding = random.randint(1, 20) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((oheight, owidth), padding=padding), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.RandomCrop((oheight, owidth), padding=padding), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == oheight assert result.size(2) == owidth - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((height, width)), - transforms.ToTensor() - ])(img) + result = transforms.Compose( + [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.ToTensor()] + )(img) assert result.size(1) == height assert result.size(2) == width torch.testing.assert_close(result, img) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == height + 1 assert result.size(2) == width + 1 @@ -1539,41 +1714,47 @@ def test_center_crop(): img = torch.ones(3, height, width) oh1 = (height - oheight) // 2 ow1 = (width - owidth) // 2 - imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth] + imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth] imgnarrow.fill_(0) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ 
+ transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) assert result.sum() == 0 oheight += 1 owidth += 1 - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) sum1 = result.sum() assert sum1 > 1 oheight += 1 owidth += 1 - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) sum2 = result.sum() assert sum2 > 0 assert sum2 > sum1 -@pytest.mark.parametrize('odd_image_size', (True, False)) -@pytest.mark.parametrize('delta', (1, 3, 5)) -@pytest.mark.parametrize('delta_width', (-2, -1, 0, 1, 2)) -@pytest.mark.parametrize('delta_height', (-2, -1, 0, 1, 2)) +@pytest.mark.parametrize("odd_image_size", (True, False)) +@pytest.mark.parametrize("delta", (1, 3, 5)) +@pytest.mark.parametrize("delta_width", (-2, -1, 0, 1, 2)) +@pytest.mark.parametrize("delta_height", (-2, -1, 0, 1, 2)) def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): - """ Tests when center crop size is larger than image size, along any dimension""" + """Tests when center crop size is larger than image size, along any dimension""" # Since height is independent of width, we can ignore images with odd height and even width and vice-versa. input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2) @@ -1587,10 +1768,8 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width) # Test both transforms, one with PIL input and one with tensor - output_pil = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop(crop_size), - transforms.ToTensor()], + output_pil = transforms.Compose( + [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.ToTensor()], )(img) assert output_pil.size()[1:3] == crop_size @@ -1615,14 +1794,14 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): output_center = output_pil[ :, - crop_center_tl[0]:crop_center_tl[0] + center_size[0], - crop_center_tl[1]:crop_center_tl[1] + center_size[1] + crop_center_tl[0] : crop_center_tl[0] + center_size[0], + crop_center_tl[1] : crop_center_tl[1] + center_size[1], ] img_center = img[ :, - input_center_tl[0]:input_center_tl[0] + center_size[0], - input_center_tl[1]:input_center_tl[1] + center_size[1] + input_center_tl[0] : input_center_tl[0] + center_size[0], + input_center_tl[1] : input_center_tl[1] + center_size[1], ] assert_equal(output_center, img_center) @@ -1634,8 +1813,8 @@ def test_color_jitter(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") for _ in range(10): y_pil = color_jitter(x_pil) @@ -1652,18 +1831,32 @@ def test_color_jitter(): def test_random_erasing(): img = torch.ones(3, 128, 128) - t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + t = 
transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0)) + y, x, h, w, v = t.get_params( + img, + t.scale, + t.ratio, + [ + t.value, + ], + ) aspect_ratio = h / w # Add some tolerance due to the rounding and int conversion used in the transform tol = 0.05 - assert (1 / 3 - tol <= aspect_ratio <= 3 + tol) + assert 1 / 3 - tol <= aspect_ratio <= 3 + tol aspect_ratios = [] random.seed(42) trial = 1000 for _ in range(trial): - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params( + img, + t.scale, + t.ratio, + [ + t.value, + ], + ) aspect_ratios.append(h / w) count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) @@ -1690,11 +1883,11 @@ def test_random_rotation(): t = transforms.RandomRotation(10) angle = t.get_params(t.degrees) - assert (angle > -10 and angle < 10) + assert angle > -10 and angle < 10 t = transforms.RandomRotation((-10, 10)) angle = t.get_params(t.degrees) - assert (-10 < angle < 10) + assert -10 < angle < 10 # Checking if RandomRotation can be printed as string t.__repr__() @@ -1730,11 +1923,12 @@ def test_randomperspective(): tr_img = F.to_tensor(tr_img) assert img.size[0] == width assert img.size[1] == height - assert (torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > - torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))) + assert torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > torch.nn.functional.mse_loss( + tr_img2, F.to_tensor(img) + ) -@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) +@pytest.mark.parametrize("mode", ["L", "RGB", "F"]) def test_randomperspective_fill(mode): # assert fill being either a Sequence or a Number @@ -1774,7 +1968,7 @@ def test_randomperspective_fill(mode): F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands)) -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_vertical_flip(): random_state = random.getstate() random.seed(42) @@ -1807,7 +2001,7 @@ def test_random_vertical_flip(): transforms.RandomVerticalFlip().__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_horizontal_flip(): random_state = random.getstate() random.seed(42) @@ -1840,10 +2034,10 @@ def test_random_horizontal_flip(): transforms.RandomHorizontalFlip().__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_normalize(): def samples_from_standard_normal(tensor): - p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue + p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue return p_value > 0.0001 random_state = random.getstate() @@ -1865,8 +2059,8 @@ def samples_from_standard_normal(tensor): assert_equal(tensor, tensor_inplace) -@pytest.mark.parametrize('dtype1', [torch.float32, torch.float64]) -@pytest.mark.parametrize('dtype2', [torch.int64, torch.float32, torch.float64]) +@pytest.mark.parametrize("dtype1", [torch.float32, torch.float64]) +@pytest.mark.parametrize("dtype2", [torch.int64, torch.float32, torch.float64]) def test_normalize_different_dtype(dtype1, dtype2): img = torch.rand(3, 10, 10, dtype=dtype1) mean = torch.tensor([1, 2, 3], dtype=dtype2) @@ -1887,15 +2081,15 @@ def test_normalize_3d_tensor(): mean_unsqueezed = mean.view(-1, 1, 1) std_unsqueezed = 
std.view(-1, 1, 1) result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed) - result2 = F.normalize(img, mean_unsqueezed.repeat(1, img_size, img_size), - std_unsqueezed.repeat(1, img_size, img_size)) + result2 = F.normalize( + img, mean_unsqueezed.repeat(1, img_size, img_size), std_unsqueezed.repeat(1, img_size, img_size) + ) torch.testing.assert_close(target, result1) torch.testing.assert_close(target, result2) class TestAffine: - - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def input_img(self): input_img = np.zeros((40, 40, 3), dtype=np.uint8) for pt in [(16, 16), (20, 16), (20, 20)]: @@ -1908,7 +2102,7 @@ def test_affine_translate_seq(self, input_img): with pytest.raises(TypeError, match=r"Argument translate should be a sequence"): F.affine(input_img, 10, translate=0, scale=1, shear=1) - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def pil_image(self, input_img): return F.to_pil_image(input_img) @@ -1929,33 +2123,29 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_ rot = a_rad # 1) Check transformation matrix: - C = np.array([[1, 0, cx], - [0, 1, cy], - [0, 0, 1]]) - T = np.array([[1, 0, tx], - [0, 1, ty], - [0, 0, 1]]) + C = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) Cinv = np.linalg.inv(C) RS = np.array( - [[scale * math.cos(rot), -scale * math.sin(rot), 0], - [scale * math.sin(rot), scale * math.cos(rot), 0], - [0, 0, 1]]) + [ + [scale * math.cos(rot), -scale * math.sin(rot), 0], + [scale * math.sin(rot), scale * math.cos(rot), 0], + [0, 0, 1], + ] + ) - SHx = np.array([[1, -math.tan(sx), 0], - [0, 1, 0], - [0, 0, 1]]) + SHx = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - SHy = np.array([[1, 0, 0], - [-math.tan(sy), 1, 0], - [0, 0, 1]]) + SHy = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) RSS = np.matmul(RS, np.matmul(SHy, SHx)) true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv))) - result_matrix = self._to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=angle, - translate=translate, scale=scale, shear=shear)) + result_matrix = self._to_3x3_inv( + F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear) + ) assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10 # 2) Perform inverse mapping: true_result = np.zeros((40, 40, 3), dtype=np.uint8) @@ -1977,38 +2167,49 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_ np_result = np.array(result) n_diff_pixels = np.sum(np_result != true_result) / 3 # Accept 3 wrong pixels - error_msg = ("angle={}, translate={}, scale={}, shear={}\n".format(angle, translate, scale, shear) + - "n diff pixels={}\n".format(n_diff_pixels)) + error_msg = "angle={}, translate={}, scale={}, shear={}\n".format( + angle, translate, scale, shear + ) + "n diff pixels={}\n".format(n_diff_pixels) assert n_diff_pixels < 3, error_msg def test_transformation_discrete(self, pil_image, input_img): # Test rotation angle = 45 - self._test_transformation(angle=angle, translate=(0, 0), scale=1.0, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test translation translate = [10, 15] - self._test_transformation(angle=0.0, translate=translate, scale=1.0, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=translate, 
scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test scale scale = 1.2 - self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=scale, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=(0.0, 0.0), scale=scale, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test shear shear = [45.0, 25.0] - self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=1.0, - shear=shear, pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img + ) @pytest.mark.parametrize("angle", range(-90, 90, 36)) @pytest.mark.parametrize("translate", range(-10, 10, 5)) @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) @pytest.mark.parametrize("shear", range(-15, 15, 5)) def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img): - self._test_transformation(angle=angle, translate=(translate, translate), scale=scale, - shear=(shear, shear), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + pil_image=pil_image, + input_img=input_img, + ) def test_random_affine(): @@ -2056,13 +2257,14 @@ def test_random_affine(): t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40]) for _ in range(100): - angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, - img_size=img.size) + angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size) assert -10 < angle < 10 - assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}" - .format(translations[0], img.size[0] * 0.5)) - assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, ("{} vs {}" - .format(translations[1], img.size[1] * 0.5)) + assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, "{} vs {}".format( + translations[0], img.size[0] * 0.5 + ) + assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, "{} vs {}".format( + translations[1], img.size[1] * 0.5 + ) assert 0.7 < scale < 1.3 assert -10 < shear[0] < 10 assert -20 < shear[1] < 40 @@ -2088,5 +2290,5 @@ def test_random_affine(): assert t.interpolation == transforms.InterpolationMode.BILINEAR -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0bf5d77716f..aea1c5883b3 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,24 +1,22 @@ import os -import torch -from torchvision import transforms as T -from torchvision.transforms import functional as F -from torchvision.transforms import InterpolationMode import numpy as np import pytest -from typing import Sequence +import torch +from torchvision import transforms as T +from torchvision.transforms import InterpolationMode, functional as F from common_utils import ( - get_tmp_dir, - int_dtypes, - float_dtypes, + _assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, - cpu_and_gpu, assert_equal, + cpu_and_gpu, + float_dtypes, + get_tmp_dir, + int_dtypes, ) NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC @@ -94,108 
+92,105 @@ def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_ _test_class_op(method, device, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( - 'func,method,fn_kwargs,match_kwargs', [ + "func,method,fn_kwargs,match_kwargs", + [ (F.hflip, T.RandomHorizontalFlip, None, {}), (F.vflip, T.RandomVerticalFlip, None, {}), (F.invert, T.RandomInvert, None, {}), (F.posterize, T.RandomPosterize, {"bits": 4}, {}), (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}), (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}), - (F.autocontrast, T.RandomAutocontrast, None, {'test_exact_match': False, - 'agg_method': 'max', 'tol': (1 + 1e-5), - 'allowed_percentage_diff': .05}), - (F.equalize, T.RandomEqualize, None, {}) - ] + ( + F.autocontrast, + T.RandomAutocontrast, + None, + {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05}, + ), + (F.equalize, T.RandomEqualize, None, {}), + ], ) def test_random(func, method, device, fn_kwargs, match_kwargs): _test_op(func, method, device, fn_kwargs, fn_kwargs, **match_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) class TestColorJitter: - - @pytest.mark.parametrize('brightness', [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]) + @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]) def test_color_jitter_brightness(self, brightness, device): tol = 1.0 + 1e-10 meth_kwargs = {"brightness": brightness} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max" + T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max" ) - @pytest.mark.parametrize('contrast', [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]) + @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]) def test_color_jitter_contrast(self, contrast, device): tol = 1.0 + 1e-10 meth_kwargs = {"contrast": contrast} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max" + T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max" ) - @pytest.mark.parametrize('saturation', [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]) + @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]) def test_color_jitter_saturation(self, saturation, device): tol = 1.0 + 1e-10 meth_kwargs = {"saturation": saturation} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max" + T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max" ) - @pytest.mark.parametrize('hue', [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]) + @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]) def test_color_jitter_hue(self, hue, device): meth_kwargs = {"hue": hue} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=16.1, agg_method="max" + T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=16.1, agg_method="max" ) def test_color_jitter_all(self, device): # All 4 parameters together meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, 
"hue": 0.2} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=12.1, agg_method="max" + T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=12.1, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('m', ["constant", "edge", "reflect", "symmetric"]) -@pytest.mark.parametrize('mul', [1, -1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"]) +@pytest.mark.parametrize("mul", [1, -1]) def test_pad(m, mul, device): fill = 127 if m == "constant" else 0 # Test functional.pad (PIL and Tensor) with padding as single int - _test_functional_op( - F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, - device=device - ) + _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device) # Test functional.pad and transforms.Pad with padding as [int, ] - fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + fn_kwargs = meth_kwargs = { + "padding": [ + mul * 2, + ], + "fill": fill, + "padding_mode": m, + } + _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test functional.pad and transforms.Pad with padding as list fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test functional.pad and transforms.Pad with padding as tuple fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_crop(device): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} # Test transforms.RandomCrop with size and padding as tuple - meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } - _test_op( - F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + meth_kwargs = { + "size": (4, 5), + "padding": (4, 4), + "pad_if_needed": True, + } + _test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test transforms.functional.crop including outside the image area fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top @@ -214,35 +209,43 @@ def test_crop(device): _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('padding_config', [ - {"padding_mode": "constant", "fill": 0}, - {"padding_mode": "constant", "fill": 10}, - {"padding_mode": "constant", "fill": 20}, - {"padding_mode": "edge"}, - {"padding_mode": "reflect"} -]) -@pytest.mark.parametrize('size', [5, [5, ], [6, 6]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "padding_config", + [ + {"padding_mode": "constant", "fill": 0}, + {"padding_mode": "constant", "fill": 10}, + {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "edge"}, + {"padding_mode": "reflect"}, + ], +) +@pytest.mark.parametrize( 
+ "size", + [ + 5, + [ + 5, + ], + [6, 6], + ], +) def test_crop_pad(size, padding_config, device): config = dict(padding_config) config["size"] = size _test_class_op(T.RandomCrop, device, config) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_center_crop(device): fn_kwargs = {"output_size": (4, 5)} - meth_kwargs = {"size": (4, 5), } - _test_op( - F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, - meth_kwargs=meth_kwargs - ) + meth_kwargs = { + "size": (4, 5), + } + _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) fn_kwargs = {"output_size": (5,)} meth_kwargs = {"size": (5,)} - _test_op( - F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, - meth_kwargs=meth_kwargs - ) + _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device) # Test torchscript of transforms.CenterCrop with size as int f = T.CenterCrop(size=5) @@ -250,7 +253,11 @@ def test_center_crop(device): scripted_fn(tensor) # Test torchscript of transforms.CenterCrop with size as [int, ] - f = T.CenterCrop(size=[5, ]) + f = T.CenterCrop( + size=[ + 5, + ] + ) scripted_fn = torch.jit.script(f) scripted_fn(tensor) @@ -263,16 +270,29 @@ def test_center_crop(device): scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('fn, method, out_length', [ - # test_five_crop - (F.five_crop, T.FiveCrop, 5), - # test_ten_crop - (F.ten_crop, T.TenCrop, 10) -]) -@pytest.mark.parametrize('size', [(5,), [5, ], (4, 5), [4, 5]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fn, method, out_length", + [ + # test_five_crop + (F.five_crop, T.FiveCrop, 5), + # test_ten_crop + (F.ten_crop, T.TenCrop, 10), + ], +) +@pytest.mark.parametrize( + "size", + [ + (5,), + [ + 5, + ], + (4, 5), + [4, 5], + ], +) def test_x_crop(fn, method, out_length, size, device): - meth_kwargs = fn_kwargs = {'size': size} + meth_kwargs = fn_kwargs = {"size": size} scripted_fn = torch.jit.script(fn) tensor, pil_img = _create_data(height=20, width=20, device=device) @@ -308,16 +328,20 @@ def test_x_crop(fn, method, out_length, size, device): assert_equal(transformed_img, transformed_batch[i, ...]) -@pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"]) +@pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"]) def test_x_crop_save(method): - fn = getattr(T, method)(size=[5, ]) + fn = getattr(T, method)( + size=[ + 5, + ] + ) scripted_fn = torch.jit.script(fn) with get_tmp_dir() as tmp_dir: scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method))) class TestResize: - @pytest.mark.parametrize('size', [32, 34, 35, 36, 38]) + @pytest.mark.parametrize("size", [32, 34, 35, 36, 38]) def test_resize_int(self, size): # TODO: Minimal check for bug-fix, improve this later x = torch.rand(3, 32, 46) @@ -329,11 +353,21 @@ def test_resize_int(self, size): assert y.shape[1] == size assert y.shape[2] == int(size * 46 / 32) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64]) - @pytest.mark.parametrize('size', [[32, ], [32, 32], (32, 32), [34, 35]]) - @pytest.mark.parametrize('max_size', [None, 35, 1000]) - @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + 
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64]) + @pytest.mark.parametrize( + "size", + [ + [ + 32, + ], + [32, 32], + (32, 32), + [34, 35], + ], + ) + @pytest.mark.parametrize("max_size", [None, 35, 1000]) + @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) def test_resize_scripted(self, dt, size, max_size, interpolation, device): tensor, _ = _create_data(height=34, width=36, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -350,16 +384,34 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resize_save(self): - transform = T.Resize(size=[32, ]) + transform = T.Resize( + size=[ + 32, + ] + ) s_transform = torch.jit.script(transform) with get_tmp_dir() as tmp_dir: s_transform.save(os.path.join(tmp_dir, "t_resize.pt")) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) - @pytest.mark.parametrize('ratio', [(0.75, 1.333), [0.75, 1.333]]) - @pytest.mark.parametrize('size', [(32,), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]) - @pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR, BICUBIC]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) + @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]]) + @pytest.mark.parametrize( + "size", + [ + (32,), + [ + 44, + ], + [ + 32, + ], + [32, 32], + (32, 32), + [44, 55], + ], + ) + @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) def test_resized_crop(self, scale, ratio, size, interpolation, device): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -369,7 +421,11 @@ def test_resized_crop(self, scale, ratio, size, interpolation, device): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resized_crop_save(self): - transform = T.RandomResizedCrop(size=[32, ]) + transform = T.RandomResizedCrop( + size=[ + 32, + ] + ) s_transform = torch.jit.script(transform) with get_tmp_dir() as tmp_dir: s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt")) @@ -385,7 +441,7 @@ def _test_random_affine_helper(device, **kwargs): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_random_affine(device): transform = T.RandomAffine(degrees=45.0) s_transform = torch.jit.script(transform) @@ -393,54 +449,76 @@ def test_random_affine(device): s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('shear', [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) def test_random_affine_shear(device, interpolation, shear): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) 
-@pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) def test_random_affine_scale(device, interpolation, scale): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('translate', [(0.1, 0.2), [0.2, 0.1]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]]) def test_random_affine_translate(device, interpolation, translate): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) def test_random_affine_degrees(device, interpolation, degrees): _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_affine_fill(device, interpolation, fill): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('center', [(0, 0), [10, 10], None, (56, 44)]) -@pytest.mark.parametrize('expand', [True, False]) -@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)]) +@pytest.mark.parametrize("expand", [True, False]) +@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_rotate(device, center, expand, degrees, interpolation, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - transform = T.RandomRotation( - degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill - ) + transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill) s_transform = torch.jit.script(transform) _test_transform_vs_scripted(transform, s_transform, tensor) @@ -454,19 +532,27 @@ def test_random_rotate_save(): s_transform.save(os.path.join(tmp_dir, 
"t_random_rotate.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('distortion_scale', np.linspace(0.1, 1.0, num=20)) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20)) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_perspective(device, distortion_scale, interpolation, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - transform = T.RandomPerspective( - distortion_scale=distortion_scale, - interpolation=interpolation, - fill=fill - ) + transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill) s_transform = torch.jit.script(transform) _test_transform_vs_scripted(transform, s_transform, tensor) @@ -480,23 +566,19 @@ def test_random_perspective_save(): s_transform.save(os.path.join(tmp_dir, "t_perspective.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('Klass, meth_kwargs', [ - (T.Grayscale, {"num_output_channels": 1}), - (T.Grayscale, {"num_output_channels": 3}), - (T.RandomGrayscale, {}) -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "Klass, meth_kwargs", + [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})], +) def test_to_grayscale(device, Klass, meth_kwargs): tol = 1.0 + 1e-10 - _test_class_op( - Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max" - ) + _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max") -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('in_dtype', int_dtypes() + float_dtypes()) -@pytest.mark.parametrize('out_dtype', int_dtypes() + float_dtypes()) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes()) +@pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes()) def test_convert_image_dtype(device, in_dtype, out_dtype): tensor, _ = _create_data(26, 34, device=device) batch_tensors = torch.rand(4, 3, 44, 56, device=device) @@ -507,8 +589,9 @@ def test_convert_image_dtype(device, in_dtype, out_dtype): fn = T.ConvertImageDtype(dtype=out_dtype) scripted_fn = torch.jit.script(fn) - if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \ - (in_dtype == torch.float64 and out_dtype == torch.int64): + if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or ( + in_dtype == torch.float64 and out_dtype == torch.int64 + ): with pytest.raises(RuntimeError, match=r"cannot be performed safely"): _test_transform_vs_scripted(fn, scripted_fn, in_tensor) with pytest.raises(RuntimeError, match=r"cannot be performed safely"): @@ -526,9 +609,22 @@ def test_convert_image_dtype_save(): scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('policy', [policy for policy in T.AutoAugmentPolicy]) -@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 
0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy]) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_autoaugment(device, policy, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -548,14 +644,10 @@ def test_autoaugment_save(): s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( - 'config', [ - {"value": 0.2}, - {"value": "random"}, - {"value": (0.2, 0.2, 0.2)}, - {"value": "random", "ratio": (0.1, 0.2)} - ] + "config", + [{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}], ) def test_random_erasing(device, config): tensor, _ = _create_data(24, 32, channels=3, device=device) @@ -582,7 +674,7 @@ def test_random_erasing_with_invalid_data(): random_erasing(img) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_normalize(device): fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) tensor, _ = _create_data(26, 34, device=device) @@ -602,7 +694,7 @@ def test_normalize(device): scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_linear_transformation(device): c, h, w = 3, 24, 32 @@ -629,14 +721,16 @@ def test_linear_transformation(device): scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_compose(device): tensor, _ = _create_data(26, 34, device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 - transforms = T.Compose([ - T.CenterCrop(10), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) + transforms = T.Compose( + [ + T.CenterCrop(10), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) s_transforms = torch.nn.Sequential(*transforms.transforms) scripted_fn = torch.jit.script(s_transforms) @@ -646,26 +740,36 @@ def test_compose(device): transformed_tensor_script = scripted_fn(tensor) assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms)) - t = T.Compose([ - lambda x: x, - ]) + t = T.Compose( + [ + lambda x: x, + ] + ) with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"): torch.jit.script(t) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_random_apply(device): tensor, _ = _create_data(26, 34, device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 - transforms = T.RandomApply([ - T.RandomHorizontalFlip(), - T.ColorJitter(), - ], p=0.4) - s_transforms = T.RandomApply(torch.nn.ModuleList([ - T.RandomHorizontalFlip(), - T.ColorJitter(), - ]), p=0.4) + transforms = T.RandomApply( + [ + T.RandomHorizontalFlip(), + T.ColorJitter(), + ], + p=0.4, + ) + s_transforms = T.RandomApply( + torch.nn.ModuleList( + [ + T.RandomHorizontalFlip(), + T.ColorJitter(), + ] + ), + p=0.4, + ) scripted_fn = torch.jit.script(s_transforms) torch.manual_seed(12) @@ -677,25 +781,30 @@ def test_random_apply(device): if device == "cpu": # Can't check this twice, 
otherwise # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply" - transforms = T.RandomApply([ - T.ColorJitter(), - ], p=0.3) + transforms = T.RandomApply( + [ + T.ColorJitter(), + ], + p=0.3, + ) with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"): torch.jit.script(transforms) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('meth_kwargs', [ - {"kernel_size": 3, "sigma": 0.75}, - {"kernel_size": 23, "sigma": [0.1, 2.0]}, - {"kernel_size": 23, "sigma": (0.1, 2.0)}, - {"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, - {"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, - {"kernel_size": [23], "sigma": 0.75} -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "meth_kwargs", + [ + {"kernel_size": 3, "sigma": 0.75}, + {"kernel_size": 23, "sigma": [0.1, 2.0]}, + {"kernel_size": 23, "sigma": (0.1, 2.0)}, + {"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, + {"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, + {"kernel_size": [23], "sigma": 0.75}, + ], +) def test_gaussian_blur(device, meth_kwargs): tol = 1.0 + 1e-10 _test_class_op( - T.GaussianBlur, meth_kwargs=meth_kwargs, - test_exact_match=False, device=device, agg_method="max", tol=tol + T.GaussianBlur, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, agg_method="max", tol=tol ) diff --git a/test/test_transforms_video.py b/test/test_transforms_video.py index 1b6b85a29ba..e3dc72f9681 100644 --- a/test/test_transforms_video.py +++ b/test/test_transforms_video.py @@ -1,9 +1,12 @@ -import torch -from torchvision.transforms import Compose -import pytest import random -import numpy as np import warnings + +import numpy as np +import pytest + +import torch +from torchvision.transforms import Compose + from common_utils import assert_equal try: @@ -17,8 +20,7 @@ import torchvision.transforms._transforms_video as transforms -class TestVideoTransforms(): - +class TestVideoTransforms: def test_random_crop_video(self): numFrames = random.randint(4, 128) height = random.randint(10, 32) * 2 @@ -26,10 +28,12 @@ def test_random_crop_video(self): oheight = random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) - result = Compose([ - transforms.ToTensorVideo(), - transforms.RandomCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.RandomCropVideo((oheight, owidth)), + ] + )(clip) assert result.size(2) == oheight assert result.size(3) == owidth @@ -42,10 +46,12 @@ def test_random_resized_crop_video(self): oheight = random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) - result = Compose([ - transforms.ToTensorVideo(), - transforms.RandomResizedCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.RandomResizedCropVideo((oheight, owidth)), + ] + )(clip) assert result.size(2) == oheight assert result.size(3) == owidth @@ -61,47 +67,56 @@ def test_center_crop_video(self): clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255 oh1 = (height - oheight) // 2 ow1 = (width - owidth) // 2 - clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :] + clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :] clipNarrow.fill_(0) - result = Compose([ - 
transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) - - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) + + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert result.sum().item() == 0, msg oheight += 1 owidth += 1 - result = Compose([ - transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) sum1 = result.sum() - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert sum1.item() > 1, msg oheight += 1 owidth += 1 - result = Compose([ - transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) sum2 = result.sum() - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert sum2.item() > 1, msg assert sum2.item() > sum1.item(), msg - @pytest.mark.skipif(stats is None, reason='scipy.stats is not available') - @pytest.mark.parametrize('channels', [1, 3]) + @pytest.mark.skipif(stats is None, reason="scipy.stats is not available") + @pytest.mark.parametrize("channels", [1, 3]) def test_normalize_video(self, channels): def samples_from_standard_normal(tensor): - p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue + p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue return p_value > 0.0001 random_state = random.getstate() @@ -142,7 +157,7 @@ def test_to_tensor_video(self): trans.__repr__() - @pytest.mark.skipif(stats is None, reason='scipy.stats not available') + @pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_horizontal_flip_video(self): random_state = random.getstate() random.seed(42) @@ -174,5 +189,5 @@ def test_random_horizontal_flip_video(self): transforms.RandomHorizontalFlipVideo().__repr__() -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_utils.py b/test/test_utils.py index 37829b906f1..17578f531d3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,21 +1,21 @@ -import pytest -import numpy as np import os import sys import tempfile +from io import BytesIO + +import numpy as np +import pytest +from PIL import Image, ImageColor, __version__ as PILLOW_VERSION + import torch +import torchvision.transforms.functional as F import torchvision.utils as utils -from io import BytesIO -import torchvision.transforms.functional as F -from PIL import Image, __version__ as PILLOW_VERSION, ImageColor from common_utils import assert_equal +PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) -PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) - -boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) +boxes = 
torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) def test_make_grid_not_inplace(): @@ -23,13 +23,13 @@ def test_make_grid_not_inplace(): t_clone = t.clone() utils.make_grid(t, normalize=False) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") utils.make_grid(t, normalize=True, scale_each=False) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") utils.make_grid(t, normalize=True, scale_each=True) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") def test_normalize_in_make_grid(): @@ -46,48 +46,48 @@ def test_normalize_in_make_grid(): rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) - assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1') - assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0') + assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1") + assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0") -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(2, 3, 64, 64) utils.save_image(t, f.name) - assert os.path.exists(f.name), 'The image is not present after save' + assert os.path.exists(f.name), "The image is not present after save" -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_single_pixel(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(1, 3, 1, 1) utils.save_image(t, f.name) - assert os.path.exists(f.name), 'The pixel image is not present after save' + assert os.path.exists(f.name), "The pixel image is not present after save" -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_file_object(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(2, 3, 64, 64) utils.save_image(t, f.name) img_orig = Image.open(f.name) fp = BytesIO() - utils.save_image(t, fp, format='png') + utils.save_image(t, fp, format="png") img_bytes = Image.open(fp) - assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') + assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object") -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_single_pixel_file_object(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = 
torch.rand(1, 3, 1, 1) utils.save_image(t, f.name) img_orig = Image.open(f.name) fp = BytesIO() - utils.save_image(t, fp, format='png') + utils.save_image(t, fp, format="png") img_bytes = Image.open(fp) - assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') + assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object") def test_draw_boxes(): @@ -113,13 +113,7 @@ def test_draw_boxes(): assert_equal(img, img_cp) -@pytest.mark.parametrize('colors', [ - None, - ['red', 'blue', '#FF00FF', (1, 34, 122)], - 'red', - '#FF00FF', - (1, 34, 122) -]) +@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)]) def test_draw_boxes_colors(colors): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) @@ -154,8 +148,7 @@ def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) - boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) with pytest.raises(ValueError, match="Tensor uint8 expected"): @@ -166,12 +159,15 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) -@pytest.mark.parametrize('colors', [ - None, - ['red', 'blue'], - ['#FF00FF', (1, 34, 122)], -]) -@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) +@pytest.mark.parametrize( + "colors", + [ + None, + ["red", "blue"], + ["#FF00FF", (1, 34, 122)], + ], +) +@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1)) def test_draw_segmentation_masks(colors, alpha): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 @@ -241,10 +237,10 @@ def test_draw_segmentation_masks_errors(): with pytest.raises(ValueError, match="There are more masks"): utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): - bad_colors = np.array(['red', 'blue']) # should be a list + bad_colors = np.array(["red", "blue"]) # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): - bad_colors = ('red', 'blue') # should be a list + bad_colors = ("red", "blue") # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 10a6c242a1e..cf239e7d6cf 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -5,13 +5,14 @@ from fractions import Fraction import numpy as np +from numpy.random import randint + import torch import torchvision.io as io -from numpy.random import randint from torchvision import set_video_backend from torchvision.io import _HAS_VIDEO_OPT -from common_utils import PY39_SKIP, assert_equal +from common_utils import PY39_SKIP, assert_equal try: import av @@ -106,18 +107,14 @@ } -DecoderResult = collections.namedtuple( - "DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase" -) +DecoderResult = 
collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase") """av_seek_frame is imprecise so seek to a timestamp earlier by a margin The unit of margin is second""" seek_frame_margin = 0.25 -def _read_from_stream( - container, start_pts, end_pts, stream, stream_name, buffer_size=4 -): +def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4): """ Args: container: pyav container @@ -229,9 +226,7 @@ def _decode_frames_by_av_module( else: aframes = torch.empty((1, 0), dtype=torch.float32) - aframe_pts = torch.tensor( - [audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64 - ) + aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64) return DecoderResult( vframes=vframes, @@ -271,24 +266,27 @@ def _get_video_tensor(video_dir, video_file): @unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg") class TestVideoReader(unittest.TestCase): def check_separate_decoding_result(self, tv_result, config): - """check the decoding results from TorchVision decoder - """ - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) - - video_duration = vduration.item() * Fraction( - vtimebase[0].item(), vtimebase[1].item() - ) + """check the decoding results from TorchVision decoder""" + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result + + video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) self.assertAlmostEqual(video_duration, config.duration, delta=0.5) self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5) if asample_rate.numel() > 0: self.assertEqual(asample_rate.item(), config.audio_sample_rate) - audio_duration = aduration.item() * Fraction( - atimebase[0].item(), atimebase[1].item() - ) + audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) self.assertAlmostEqual(audio_duration, config.duration, delta=0.5) # check if pts of video frames are sorted in ascending order @@ -302,16 +300,12 @@ def check_separate_decoding_result(self, tv_result, config): def check_probe_result(self, result, config): vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result - video_duration = vduration.item() * Fraction( - vtimebase[0].item(), vtimebase[1].item() - ) + video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) self.assertAlmostEqual(video_duration, config.duration, delta=0.5) self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5) if asample_rate.numel() > 0: self.assertEqual(asample_rate.item(), config.audio_sample_rate) - audio_duration = aduration.item() * Fraction( - atimebase[0].item(), atimebase[1].item() - ) + audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) self.assertAlmostEqual(audio_duration, config.duration, delta=0.5) def check_meta_result(self, result, config): @@ -330,10 +324,18 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config decoder or TorchVision decoder with getPtsOnly = 1 config: config of decoding results checker """ - vframes, vframe_pts, vtimebase, _vfps, _vduration, \ - aframes, aframe_pts, atimebase, _asample_rate, _aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + _vfps, + _vduration, + aframes, + aframe_pts, + atimebase, + _asample_rate, + _aduration, + ) = 
tv_result if isinstance(ref_result, list): # the ref_result is from new video_reader decoder ref_result = DecoderResult( @@ -346,32 +348,20 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config ) if vframes.numel() > 0 and ref_result.vframes.numel() > 0: - mean_delta = torch.mean( - torch.abs(vframes.float() - ref_result.vframes.float()) - ) + mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float())) self.assertAlmostEqual(mean_delta, 0, delta=8.0) - mean_delta = torch.mean( - torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()) - ) + mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())) self.assertAlmostEqual(mean_delta, 0, delta=1.0) assert_equal(vtimebase, ref_result.vtimebase) - if ( - config.check_aframes - and aframes.numel() > 0 - and ref_result.aframes.numel() > 0 - ): + if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0: """Audio stream is available and audio frame is required to return from decoder""" assert_equal(aframes, ref_result.aframes) - if ( - config.check_aframe_pts - and aframe_pts.numel() > 0 - and ref_result.aframe_pts.numel() > 0 - ): + if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0: """Audio stream is available""" assert_equal(aframe_pts, ref_result.aframe_pts) @@ -507,19 +497,25 @@ def test_read_video_from_file_read_single_stream_only(self): audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result self.assertEqual(vframes.numel() > 0, readVideoStream) self.assertEqual(vframe_pts.numel() > 0, readVideoStream) self.assertEqual(vtimebase.numel() > 0, readVideoStream) self.assertEqual(vfps.numel() > 0, readVideoStream) - expect_audio_data = ( - readAudioStream == 1 and config.audio_sample_rate is not None - ) + expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None self.assertEqual(aframes.numel() > 0, expect_audio_data) self.assertEqual(aframe_pts.numel() > 0, expect_audio_data) self.assertEqual(atimebase.numel() > 0, expect_audio_data) @@ -563,9 +559,7 @@ def test_read_video_from_file_rescale_min_dimension(self): audio_timebase_num, audio_timebase_den, ) - self.assertEqual( - min_dimension, min(tv_result[0].size(1), tv_result[0].size(2)) - ) + self.assertEqual(min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))) def test_read_video_from_file_rescale_max_dimension(self): """ @@ -605,9 +599,7 @@ def test_read_video_from_file_rescale_max_dimension(self): audio_timebase_num, audio_timebase_den, ) - self.assertEqual( - max_dimension, max(tv_result[0].size(1), tv_result[0].size(2)) - ) + self.assertEqual(max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))) def test_read_video_from_file_rescale_both_min_max_dimension(self): """ @@ -647,12 +639,8 @@ def test_read_video_from_file_rescale_both_min_max_dimension(self): audio_timebase_num, audio_timebase_den, ) - self.assertEqual( - min_dimension, min(tv_result[0].size(1), tv_result[0].size(2)) - ) - self.assertEqual( - max_dimension, max(tv_result[0].size(1), tv_result[0].size(2)) - ) + self.assertEqual(min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))) + self.assertEqual(max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))) def 
test_read_video_from_file_rescale_width(self): """ @@ -816,19 +804,23 @@ def test_read_video_from_file_audio_resampling(self): audio_timebase_num, audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result if aframes.numel() > 0: self.assertEqual(samples, asample_rate.item()) self.assertEqual(1, aframes.size(1)) # when audio stream is found - duration = ( - float(aframe_pts[-1]) - * float(atimebase[0]) - / float(atimebase[1]) - ) + duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1]) self.assertAlmostEqual( aframes.size(0), int(duration * asample_rate.item()), @@ -1056,10 +1048,18 @@ def test_read_video_in_range_from_memory(self): audio_timebase_num, audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01) for num_frames in [4, 8, 16, 32, 64, 128]: @@ -1113,41 +1113,31 @@ def test_read_video_in_range_from_memory(self): ) # pass 3: decode frames in range using PyAv - video_timebase_av, audio_timebase_av = _get_timebase_by_av_module( - full_path - ) + video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path) video_start_pts_av = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - video_timebase_av.numerator, video_timebase_av.denominator - ), + Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.floor, ) video_end_pts_av = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - video_timebase_av.numerator, video_timebase_av.denominator - ), + Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.ceil, ) if audio_timebase_av: audio_start_pts = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - audio_timebase_av.numerator, audio_timebase_av.denominator - ), + Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.floor, ) audio_end_pts = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - audio_timebase_av.numerator, audio_timebase_av.denominator - ), + Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.ceil, ) @@ -1235,17 +1225,17 @@ def test_read_video_from_memory_scripted(self): # FUTURE: check value of video / audio frames def test_invalid_file(self): - set_video_backend('video_reader') + set_video_backend("video_reader") with self.assertRaises(RuntimeError): - io.read_video('foo.mp4') + io.read_video("foo.mp4") - set_video_backend('pyav') + set_video_backend("pyav") with self.assertRaises(RuntimeError): - io.read_video('foo.mp4') + io.read_video("foo.mp4") def test_audio_present(self): """Test if audio frames are returned with video_reader backend.""" - set_video_backend('video_reader') + set_video_backend("video_reader") for test_video, _ in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video) container = av.open(full_path) diff --git a/test/test_videoapi.py 
b/test/test_videoapi.py index da73c7cd17d..96ec44c6f22 100644 --- a/test/test_videoapi.py +++ b/test/test_videoapi.py @@ -4,8 +4,8 @@ import torch import torchvision -from torchvision.io import _HAS_VIDEO_OPT, VideoReader from torchvision.datasets.utils import download_url +from torchvision.io import _HAS_VIDEO_OPT, VideoReader from common_utils import PY39_SKIP @@ -35,30 +35,16 @@ def fate(name, path="."): test_videos = { - "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( - duration=2.0, video_fps=30.0, audio_sample_rate=None - ), + "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None), "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( duration=2.0, video_fps=30.0, audio_sample_rate=None ), - "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( - duration=2.0, video_fps=30.0, audio_sample_rate=None - ), - "v_SoccerJuggling_g23_c01.avi": GroundTruth( - duration=8.0, video_fps=29.97, audio_sample_rate=None - ), - "v_SoccerJuggling_g24_c01.avi": GroundTruth( - duration=8.0, video_fps=29.97, audio_sample_rate=None - ), - "R6llTwEh07w.mp4": GroundTruth( - duration=10.0, video_fps=30.0, audio_sample_rate=44100 - ), - "SOX5yA1l24A.mp4": GroundTruth( - duration=11.0, video_fps=29.97, audio_sample_rate=48000 - ), - "WUzgd7C1pWA.mp4": GroundTruth( - duration=11.0, video_fps=29.97, audio_sample_rate=48000 - ), + "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None), + "v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None), + "v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None), + "R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100), + "SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000), + "WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000), } @@ -83,13 +69,9 @@ def test_frame_reading(self): delta=0.1, ) - av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute( - 2, 0, 1 - ) + av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1) vr_array = vr_frame["data"] - mean_delta = torch.mean( - torch.abs(av_array.float() - vr_array.float()) - ) + mean_delta = torch.mean(torch.abs(av_array.float() - vr_array.float())) # on average the difference is very small and caused # by decoding (around 1%) # TODO: asses empirically how to set this? 
atm it's 1% @@ -110,9 +92,7 @@ def test_frame_reading(self): av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0) vr_array = vr_frame["data"] - max_delta = torch.max( - torch.abs(av_array.float() - vr_array.float()) - ) + max_delta = torch.max(torch.abs(av_array.float() - vr_array.float())) # we assure that there is never more than 1% difference in signal self.assertTrue(max_delta.item() < 0.001) @@ -125,12 +105,8 @@ def test_metadata(self): full_path = os.path.join(VIDEO_DIR, test_video) reader = VideoReader(full_path, "video") reader_md = reader.get_metadata() - self.assertAlmostEqual( - config.video_fps, reader_md["video"]["fps"][0], delta=0.0001 - ) - self.assertAlmostEqual( - config.duration, reader_md["video"]["duration"][0], delta=0.5 - ) + self.assertAlmostEqual(config.video_fps, reader_md["video"]["fps"][0], delta=0.0001) + self.assertAlmostEqual(config.duration, reader_md["video"]["duration"][0], delta=0.5) def test_seek_start(self): for test_video, config in test_videos.items(): diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 34961e8684f..8cc1d344936 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -1,4 +1,3 @@ - import os.path as osp import torch diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 3cbdda7af7f..d3a5921183b 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -1,31 +1,28 @@ -import warnings import os - -from .extension import _HAS_OPS - -from torchvision import models -from torchvision import datasets -from torchvision import ops -from torchvision import transforms -from torchvision import utils -from torchvision import io +import warnings import torch +from torchvision import datasets, io, models, ops, transforms, utils + +from .extension import _HAS_OPS try: - from .version import __version__ # noqa: F401 + from .version import __version__ except ImportError: pass # Check if torchvision is being imported within the root folder -if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == - os.path.join(os.path.realpath(os.getcwd()), 'torchvision')): - message = ('You are importing torchvision within its own root folder ({}). ' - 'This is not expected to work and may give errors. Please exit the ' - 'torchvision project source and relaunch your python interpreter.') +if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join( + os.path.realpath(os.getcwd()), "torchvision" +): + message = ( + "You are importing torchvision within its own root folder ({}). " + "This is not expected to work and may give errors. Please exit the " + "torchvision project source and relaunch your python interpreter." + ) warnings.warn(message.format(os.getcwd())) -_image_backend = 'PIL' +_image_backend = "PIL" _video_backend = "pyav" @@ -40,9 +37,8 @@ def set_image_backend(backend): generally faster than PIL, but does not support as many operations. """ global _image_backend - if backend not in ['PIL', 'accimage']: - raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'" - .format(backend)) + if backend not in ["PIL", "accimage"]: + raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend)) _image_backend = backend @@ -71,14 +67,9 @@ def set_video_backend(backend): """ global _video_backend if backend not in ["pyav", "video_reader"]: - raise ValueError( - "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend - ) + raise ValueError("Invalid video backend '%s'. 
Options are 'pyav' and 'video_reader'" % backend) if backend == "video_reader" and not io._HAS_VIDEO_OPT: - message = ( - "video_reader video backend is not available." - " Please compile torchvision from source and try again" - ) + message = "video_reader video backend is not available." " Please compile torchvision from source and try again" warnings.warn(message) else: _video_backend = backend diff --git a/torchvision/_internally_replaced_utils.py b/torchvision/_internally_replaced_utils.py index 0ab3e4e3f15..6215a8a2903 100644 --- a/torchvision/_internally_replaced_utils.py +++ b/torchvision/_internally_replaced_utils.py @@ -1,5 +1,5 @@ -import os import importlib.machinery +import os def _download_file_from_remote_location(fpath: str, url: str) -> None: @@ -13,19 +13,19 @@ def _is_remote_location_available() -> bool: try: from torch.hub import load_state_dict_from_url except ImportError: - from torch.utils.model_zoo import load_url as load_state_dict_from_url + from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: F401 def _get_extension_path(lib_name): lib_dir = os.path.dirname(__file__) - if os.name == 'nt': + if os.name == "nt": # Register the main torchvision library location on the default DLL path import ctypes import sys - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") prev_error_mode = kernel32.SetErrorMode(0x0001) if with_load_library_flags: @@ -42,10 +42,7 @@ def _get_extension_path(lib_name): kernel32.SetErrorMode(prev_error_mode) - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES - ) + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) ext_specs = extfinder.find_spec(lib_name) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 3b4a5408ecf..ba8e6609325 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,40 +1,71 @@ -from .lsun import LSUN, LSUNClass -from .folder import ImageFolder, DatasetFolder -from .coco import CocoCaptions, CocoDetection +from .caltech import Caltech101, Caltech256 +from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 -from .stl10 import STL10 -from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST -from .svhn import SVHN -from .phototour import PhotoTour +from .cityscapes import Cityscapes +from .coco import CocoCaptions, CocoDetection from .fakedata import FakeData -from .semeion import SEMEION -from .omniglot import Omniglot -from .sbu import SBU from .flickr import Flickr8k, Flickr30k -from .voc import VOCSegmentation, VOCDetection -from .cityscapes import Cityscapes +from .folder import DatasetFolder, ImageFolder +from .hmdb51 import HMDB51 from .imagenet import ImageNet -from .caltech import Caltech101, Caltech256 -from .celeba import CelebA -from .widerface import WIDERFace +from .inaturalist import INaturalist +from .kinetics import Kinetics, Kinetics400 +from .kitti import Kitti +from .lsun import LSUN, LSUNClass +from .mnist import EMNIST, KMNIST, MNIST, QMNIST, FashionMNIST +from .omniglot import Omniglot +from .phototour import PhotoTour +from .places365 import Places365 from .sbd import SBDataset -from .vision import VisionDataset -from .usps import 
USPS -from .kinetics import Kinetics400, Kinetics -from .hmdb51 import HMDB51 +from .sbu import SBU +from .semeion import SEMEION +from .stl10 import STL10 +from .svhn import SVHN from .ucf101 import UCF101 -from .places365 import Places365 -from .kitti import Kitti -from .inaturalist import INaturalist +from .usps import USPS +from .vision import VisionDataset +from .voc import VOCDetection, VOCSegmentation +from .widerface import WIDERFace -__all__ = ('LSUN', 'LSUNClass', - 'ImageFolder', 'DatasetFolder', 'FakeData', - 'CocoCaptions', 'CocoDetection', - 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST', - 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', - 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', - 'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset', - 'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101', - 'Places365', 'Kitti', "INaturalist" - ) +__all__ = ( + "LSUN", + "LSUNClass", + "ImageFolder", + "DatasetFolder", + "FakeData", + "CocoCaptions", + "CocoDetection", + "CIFAR10", + "CIFAR100", + "EMNIST", + "FashionMNIST", + "QMNIST", + "MNIST", + "KMNIST", + "STL10", + "SVHN", + "PhotoTour", + "SEMEION", + "Omniglot", + "SBU", + "Flickr8k", + "Flickr30k", + "VOCSegmentation", + "VOCDetection", + "Cityscapes", + "ImageNet", + "Caltech101", + "Caltech256", + "CelebA", + "WIDERFace", + "SBDataset", + "VisionDataset", + "USPS", + "Kinetics400", + "Kinetics", + "HMDB51", + "UCF101", + "Places365", + "Kitti", + "INaturalist", +) diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index 1a254edb430..9d0f74d5492 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -1,10 +1,11 @@ -from PIL import Image import os import os.path -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image -from .vision import VisionDataset from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset class Caltech101(VisionDataset): @@ -31,28 +32,26 @@ class Caltech101(VisionDataset): """ def __init__( - self, - root: str, - target_type: Union[List[str], str] = "category", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + target_type: Union[List[str], str] = "category", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Caltech101, self).__init__(os.path.join(root, 'caltech101'), - transform=transform, - target_transform=target_transform) + super(Caltech101, self).__init__( + os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform + ) os.makedirs(self.root, exist_ok=True) if not isinstance(target_type, list): target_type = [target_type] - self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type] if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) self.categories.remove("BACKGROUND_Google") # this is not a real class @@ -60,10 +59,12 @@ def __init__( # For some reason, the category names in "101_ObjectCategories" and # "Annotations" do not always match. This is a manual map between the # two. Defaults to using same name, since most names are fine. - name_map = {"Faces": "Faces_2", - "Faces_easy": "Faces_3", - "Motorbikes": "Motorbikes_16", - "airplanes": "Airplanes_Side_2"} + name_map = { + "Faces": "Faces_2", + "Faces_easy": "Faces_3", + "Motorbikes": "Motorbikes_16", + "airplanes": "Airplanes_Side_2", + } self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) self.index: List[int] = [] @@ -83,20 +84,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ import scipy.io - img = Image.open(os.path.join(self.root, - "101_ObjectCategories", - self.categories[self.y[index]], - "image_{:04d}.jpg".format(self.index[index]))) + img = Image.open( + os.path.join( + self.root, + "101_ObjectCategories", + self.categories[self.y[index]], + "image_{:04d}.jpg".format(self.index[index]), + ) + ) target: Any = [] for t in self.target_type: if t == "category": target.append(self.y[index]) elif t == "annotation": - data = scipy.io.loadmat(os.path.join(self.root, - "Annotations", - self.annotation_categories[self.y[index]], - "annotation_{:04d}.mat".format(self.index[index]))) + data = scipy.io.loadmat( + os.path.join( + self.root, + "Annotations", + self.annotation_categories[self.y[index]], + "annotation_{:04d}.mat".format(self.index[index]), + ) + ) target.append(data["obj_contour"]) target = tuple(target) if len(target) > 1 else target[0] @@ -117,19 +126,21 @@ def __len__(self) -> int: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", self.root, filename="101_ObjectCategories.tar.gz", - md5="b224c7392d521a49829488ab0f1120d9") + md5="b224c7392d521a49829488ab0f1120d9", + ) download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", self.root, filename="101_Annotations.tar", - md5="6f83eeb1f24d99cab4eb377263132c91") + md5="6f83eeb1f24d99cab4eb377263132c91", + ) def extra_repr(self) -> str: return "Target type: {target_type}".format(**self.__dict__) @@ -151,23 +162,22 @@ class Caltech256(VisionDataset): """ def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Caltech256, self).__init__(os.path.join(root, 'caltech256'), - transform=transform, - target_transform=target_transform) + super(Caltech256, self).__init__( + os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform + ) os.makedirs(self.root, exist_ok=True) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.index: List[int] = [] @@ -185,10 +195,14 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: Returns: tuple: (image, target) where target is index of the target class. """ - img = Image.open(os.path.join(self.root, - "256_ObjectCategories", - self.categories[self.y[index]], - "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) + img = Image.open( + os.path.join( + self.root, + "256_ObjectCategories", + self.categories[self.y[index]], + "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]), + ) + ) target = self.y[index] @@ -209,11 +223,12 @@ def __len__(self) -> int: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", self.root, filename="256_ObjectCategories.tar", - md5="67b4f42ca05d46448c6bb8ecd2220f6d") + md5="67b4f42ca05d46448c6bb8ecd2220f6d", + ) diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index 56588aaef57..c455bc464c0 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,12 +1,15 @@ -from collections import namedtuple import csv -from functools import partial -import torch import os +from collections import namedtuple +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + import PIL -from typing import Any, Callable, List, Optional, Union, Tuple + +import torch + +from .utils import check_integrity, download_file_from_google_drive, verify_str_arg from .vision import VisionDataset -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -57,16 +60,15 @@ class CelebA(VisionDataset): ] def __init__( - self, - root: str, - split: str = "train", - target_type: Union[List[str], str] = "attr", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + target_type: Union[List[str], str] = "attr", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(CelebA, self).__init__(root, transform=transform, - target_transform=target_transform) + super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform) self.split = split if isinstance(target_type, list): self.target_type = target_type @@ -74,14 +76,13 @@ def __init__( self.target_type = [target_type] if not self.target_type and self.target_transform is not None: - raise RuntimeError('target_transform is specified but target_type is empty') + raise RuntimeError("target_transform is specified but target_type is empty") if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") split_map = { "train": 0, @@ -89,8 +90,7 @@ def __init__( "test": 2, "all": None, } - split_ = split_map[verify_str_arg(split.lower(), "split", - ("train", "valid", "test", "all"))] + split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))] splits = self._load_csv("list_eval_partition.txt") identity = self._load_csv("identity_CelebA.txt") bbox = self._load_csv("list_bbox_celeba.txt", header=1) @@ -105,7 +105,7 @@ def __init__( self.landmarks_align = landmarks_align.data[mask] self.attr = attr.data[mask] # map from {-1, 1} to {0, 1} - self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor') + self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor") self.attr_names = attr.header def _load_csv( @@ -117,11 +117,11 @@ def _load_csv( fn = partial(os.path.join, self.root, self.base_folder) with open(fn(filename)) as csv_file: - data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True)) + data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True)) if header is not None: headers = data[header] - data = data[header + 1:] + data = data[header + 1 :] indices = [row[0] for row in data] data = [row[1:] for row in data] @@ -145,7 +145,7 @@ def download(self) -> None: import zipfile if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return for (file_id, md5, filename) in self.file_list: @@ -169,7 +169,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: target.append(self.landmarks_align[index, :]) else: # TODO: refactor with utils.verify_str_arg - raise ValueError("Target type \"{}\" is not recognized.".format(t)) + raise ValueError('Target type "{}" is not recognized.'.format(t)) if self.transform is not None: X = self.transform(X) @@ -189,4 +189,4 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Target type: {target_type}", "Split: {split}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 47b2bd41fb0..3ea51bfe8be 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -1,13 +1,15 @@ -from PIL import Image import os import os.path -import numpy as np import pickle -import torch from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset +import numpy as np +from PIL import Image + +import torch + from .utils import check_integrity, download_and_extract_archive +from .vision import VisionDataset class CIFAR10(VisionDataset): @@ -27,38 +29,38 @@ class CIFAR10(VisionDataset): downloaded again. 
""" - base_folder = 'cifar-10-batches-py' + + base_folder = "cifar-10-batches-py" url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" filename = "cifar-10-python.tar.gz" - tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + tgz_md5 = "c58f30108f718f92721af3b95e74349a" train_list = [ - ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], - ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], - ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], - ['data_batch_4', '634d18415352ddfa80567beed471001a'], - ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ["data_batch_1", "c99cafc152244af753f735de768cd75f"], + ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], + ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], + ["data_batch_4", "634d18415352ddfa80567beed471001a"], + ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], ] test_list = [ - ['test_batch', '40351d587109b95175f43aff81a1287e'], + ["test_batch", "40351d587109b95175f43aff81a1287e"], ] meta = { - 'filename': 'batches.meta', - 'key': 'label_names', - 'md5': '5ff9c542aee3614f3951f8cda6e48888', + "filename": "batches.meta", + "key": "label_names", + "md5": "5ff9c542aee3614f3951f8cda6e48888", } def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(CIFAR10, self).__init__(root, transform=transform, - target_transform=target_transform) + super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform) torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}") self.train = train # training set or test set @@ -67,8 +69,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") if self.train: downloaded_list = self.train_list @@ -81,13 +82,13 @@ def __init__( # now load the picked numpy arrays for file_name, checksum in downloaded_list: file_path = os.path.join(self.root, self.base_folder, file_name) - with open(file_path, 'rb') as f: - entry = pickle.load(f, encoding='latin1') - self.data.append(entry['data']) - if 'labels' in entry: - self.targets.extend(entry['labels']) + with open(file_path, "rb") as f: + entry = pickle.load(f, encoding="latin1") + self.data.append(entry["data"]) + if "labels" in entry: + self.targets.extend(entry["labels"]) else: - self.targets.extend(entry['fine_labels']) + self.targets.extend(entry["fine_labels"]) self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC @@ -95,13 +96,14 @@ def __init__( self._load_meta() def _load_meta(self) -> None: - path = os.path.join(self.root, self.base_folder, self.meta['filename']) - if not check_integrity(path, self.meta['md5']): - raise RuntimeError('Dataset metadata file not found or corrupted.' 
+ - ' You can use download=True to download it') - with open(path, 'rb') as infile: - data = pickle.load(infile, encoding='latin1') - self.classes = data[self.meta['key']] + path = os.path.join(self.root, self.base_folder, self.meta["filename"]) + if not check_integrity(path, self.meta["md5"]): + raise RuntimeError( + "Dataset metadata file not found or corrupted." + " You can use download=True to download it" + ) + with open(path, "rb") as infile: + data = pickle.load(infile, encoding="latin1") + self.classes = data[self.meta["key"]] self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} def __getitem__(self, index: int) -> Tuple[Any, Any]: @@ -131,7 +133,7 @@ def __len__(self) -> int: def _check_integrity(self) -> bool: root = self.root - for fentry in (self.train_list + self.test_list): + for fentry in self.train_list + self.test_list: filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): @@ -140,7 +142,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) @@ -153,19 +155,20 @@ class CIFAR100(CIFAR10): This is a subclass of the `CIFAR10` Dataset. """ - base_folder = 'cifar-100-python' + + base_folder = "cifar-100-python" url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" filename = "cifar-100-python.tar.gz" - tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" train_list = [ - ['train', '16019d7e3df5f24257cddd939b257f8d'], + ["train", "16019d7e3df5f24257cddd939b257f8d"], ] test_list = [ - ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], ] meta = { - 'filename': 'meta', - 'key': 'fine_label_names', - 'md5': '7973b15100ade9c7d40fb424638fde48', + "filename": "meta", + "key": "fine_label_names", + "md5": "7973b15100ade9c7d40fb424638fde48", } diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index bed7524ac4f..05e63d3ecf8 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -1,12 +1,13 @@ import json import os from collections import namedtuple -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from .utils import extract_archive, verify_str_arg, iterable_to_str -from .vision import VisionDataset from PIL import Image +from .utils import extract_archive, iterable_to_str, verify_str_arg +from .vision import VisionDataset + class Cityscapes(VisionDataset): """`Cityscapes `_ Dataset. 
@@ -57,60 +58,62 @@ class Cityscapes(VisionDataset): """ # Based on https://github.com/mcordts/cityscapesScripts - CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', - 'has_instances', 'ignore_in_eval', 'color']) + CityscapesClass = namedtuple( + "CityscapesClass", + ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"], + ) classes = [ - CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), - CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), - CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), - CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), - CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), - CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), - CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), - CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), - CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), - CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), - CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), - CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), - CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), - CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), - CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), - CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), - CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), - CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), - CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), - CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), - CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), - CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), - CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), - CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), - CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), - CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), - CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), - CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), - CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), - CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), + CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)), + 
CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)), + CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)), + CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)), + CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)), + CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)), + CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)), + CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)), + CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)), + CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)), + CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)), + CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)), + CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)), + CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)), + CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)), + CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)), + CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)), + CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)), + CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)), + CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)), + CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)), + CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)), + CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)), + CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)), + CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)), + CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)), + CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)), + CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)), + CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)), + CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)), + CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)), ] def __init__( - self, - root: str, - split: str = "train", - mode: str = "fine", - target_type: Union[List[str], str] = "instance", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - transforms: Optional[Callable] = None, + self, + root: str, + split: str = "train", + mode: str = "fine", + target_type: Union[List[str], str] = "instance", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, ) -> None: super(Cityscapes, self).__init__(root, transforms, transform, target_transform) - self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' - self.images_dir = os.path.join(self.root, 'leftImg8bit', split) + self.mode = "gtFine" if mode == "fine" else "gtCoarse" + self.images_dir = os.path.join(self.root, "leftImg8bit", split) self.targets_dir = os.path.join(self.root, self.mode, split) self.target_type = target_type self.split = split @@ -122,35 +125,37 @@ def __init__( valid_modes = ("train", "test", "val") else: valid_modes = ("train", "train_extra", 
"val") - msg = ("Unknown value '{}' for argument split if mode is '{}'. " - "Valid values are {{{}}}.") + msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}." msg = msg.format(split, mode, iterable_to_str(valid_modes)) verify_str_arg(split, "split", valid_modes, msg) if not isinstance(target_type, list): self.target_type = [target_type] - [verify_str_arg(value, "target_type", - ("instance", "semantic", "polygon", "color")) - for value in self.target_type] + [ + verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) + for value in self.target_type + ] if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): - if split == 'train_extra': - image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) + if split == "train_extra": + image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip")) else: - image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) + image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip")) - if self.mode == 'gtFine': - target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) - elif self.mode == 'gtCoarse': - target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) + if self.mode == "gtFine": + target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip")) + elif self.mode == "gtCoarse": + target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip")) if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): extract_archive(from_path=image_dir_zip, to_path=self.root) extract_archive(from_path=target_dir_zip, to_path=self.root) else: - raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' - ' specified "split" and "mode" are inside the "root" directory') + raise RuntimeError( + "Dataset not found or incomplete. Please make sure all required folders for the" + ' specified "split" and "mode" are inside the "root" directory' + ) for city in os.listdir(self.images_dir): img_dir = os.path.join(self.images_dir, city) @@ -158,8 +163,9 @@ def __init__( for file_name in os.listdir(img_dir): target_types = [] for t in self.target_type: - target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], - self._get_target_suffix(self.mode, t)) + target_name = "{}_{}".format( + file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t) + ) target_types.append(os.path.join(target_dir, target_name)) self.images.append(os.path.join(img_dir, file_name)) @@ -174,11 +180,11 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
""" - image = Image.open(self.images[index]).convert('RGB') + image = Image.open(self.images[index]).convert("RGB") targets: Any = [] for i, t in enumerate(self.target_type): - if t == 'polygon': + if t == "polygon": target = self._load_json(self.targets[index][i]) else: target = Image.open(self.targets[index][i]) @@ -197,19 +203,19 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) def _load_json(self, path: str) -> Dict[str, Any]: - with open(path, 'r') as file: + with open(path, "r") as file: data = json.load(file) return data def _get_target_suffix(self, mode: str, target_type: str) -> str: - if target_type == 'instance': - return '{}_instanceIds.png'.format(mode) - elif target_type == 'semantic': - return '{}_labelIds.png'.format(mode) - elif target_type == 'color': - return '{}_color.png'.format(mode) + if target_type == "instance": + return "{}_instanceIds.png".format(mode) + elif target_type == "semantic": + return "{}_labelIds.png".format(mode) + elif target_type == "color": + return "{}_color.png".format(mode) else: - return '{}_polygons.json'.format(mode) + return "{}_polygons.json".format(mode) diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index d59a23efb4d..f87f91c8a3c 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -1,8 +1,10 @@ -from .vision import VisionDataset -from PIL import Image import os import os.path -from typing import Any, Callable, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple + +from PIL import Image + +from .vision import VisionDataset class CocoDetection(VisionDataset): diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py index ddb14505275..24c5c0d1722 100644 --- a/torchvision/datasets/fakedata.py +++ b/torchvision/datasets/fakedata.py @@ -1,7 +1,9 @@ -import torch from typing import Any, Callable, Optional, Tuple + +import torch +from torchvision import transforms + from .vision import VisionDataset -from .. 
import transforms class FakeData(VisionDataset): @@ -21,16 +23,17 @@ class FakeData(VisionDataset): """ def __init__( - self, - size: int = 1000, - image_size: Tuple[int, int, int] = (3, 224, 224), - num_classes: int = 10, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - random_offset: int = 0, + self, + size: int = 1000, + image_size: Tuple[int, int, int] = (3, 224, 224), + num_classes: int = 10, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + random_offset: int = 0, ) -> None: - super(FakeData, self).__init__(None, transform=transform, # type: ignore[arg-type] - target_transform=target_transform) + super(FakeData, self).__init__( + None, transform=transform, target_transform=target_transform # type: ignore[arg-type] + ) self.size = size self.num_classes = num_classes self.image_size = image_size diff --git a/torchvision/datasets/flickr.py b/torchvision/datasets/flickr.py index a3b3e411b6e..31cb68d4937 100644 --- a/torchvision/datasets/flickr.py +++ b/torchvision/datasets/flickr.py @@ -1,10 +1,11 @@ +import glob +import os from collections import defaultdict -from PIL import Image from html.parser import HTMLParser from typing import Any, Callable, Dict, List, Optional, Tuple -import glob -import os +from PIL import Image + from .vision import VisionDataset @@ -27,26 +28,26 @@ def __init__(self, root: str) -> None: def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None: self.current_tag = tag - if tag == 'table': + if tag == "table": self.in_table = True def handle_endtag(self, tag: str) -> None: self.current_tag = None - if tag == 'table': + if tag == "table": self.in_table = False def handle_data(self, data: str) -> None: if self.in_table: - if data == 'Image Not Found': + if data == "Image Not Found": self.current_img = None - elif self.current_tag == 'a': - img_id = data.split('/')[-2] - img_id = os.path.join(self.root, img_id + '_*.jpg') + elif self.current_tag == "a": + img_id = data.split("/")[-2] + img_id = os.path.join(self.root, img_id + "_*.jpg") img_id = glob.glob(img_id)[0] self.current_img = img_id self.annotations[img_id] = [] - elif self.current_tag == 'li' and self.current_img: + elif self.current_tag == "li" and self.current_img: img_id = self.current_img self.annotations[img_id].append(data.strip()) @@ -64,14 +65,13 @@ class Flickr8k(VisionDataset): """ def __init__( - self, - root: str, - ann_file: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + ann_file: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: - super(Flickr8k, self).__init__(root, transform=transform, - target_transform=target_transform) + super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict @@ -93,7 +93,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: img_id = self.ids[index] # Image - img = Image.open(img_id).convert('RGB') + img = Image.open(img_id).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -121,21 +121,20 @@ class Flickr30k(VisionDataset): """ def __init__( - self, - root: str, - ann_file: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + ann_file: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: 
- super(Flickr30k, self).__init__(root, transform=transform, - target_transform=target_transform) + super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict self.annotations = defaultdict(list) with open(self.ann_file) as fh: for line in fh: - img_id, caption = line.strip().split('\t') + img_id, caption = line.strip().split("\t") self.annotations[img_id[:-2]].append(caption) self.ids = list(sorted(self.annotations.keys())) @@ -152,7 +151,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # Image filename = os.path.join(self.root, img_id) - img = Image.open(filename).convert('RGB') + img = Image.open(filename).convert("RGB") if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 1280495273c..bf91a03a5ec 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,10 +1,10 @@ -from .vision import VisionDataset +import os +import os.path +from typing import Any, Callable, Dict, List, Optional, Tuple, cast from PIL import Image -import os -import os.path -from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from .vision import VisionDataset def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: @@ -132,16 +132,15 @@ class DatasetFolder(VisionDataset): """ def __init__( - self, - root: str, - loader: Callable[[str], Any], - extensions: Optional[Tuple[str, ...]] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: - super(DatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) + super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self.find_classes(self.root) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) @@ -186,9 +185,7 @@ def make_dataset( # prevent potential bug since make_dataset() would use the class_to_idx logic of the # find_classes() function, instead of using that of the find_classes() method, which # is potentially overridden and thus could have a different logic. - raise ValueError( - "The class_to_idx parameter cannot be None." 
- ) + raise ValueError("The class_to_idx parameter cannot be None.") return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: @@ -241,19 +238,20 @@ def __len__(self) -> int: return len(self.samples) -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) - return img.convert('RGB') + return img.convert("RGB") # TODO: specify the return type def accimage_loader(path: str) -> Any: import accimage + try: return accimage.Image(path) except IOError: @@ -263,7 +261,8 @@ def accimage_loader(path: str) -> Any: def default_loader(path: str) -> Any: from torchvision import get_image_backend - if get_image_backend() == 'accimage': + + if get_image_backend() == "accimage": return accimage_loader(path) else: return pil_loader(path) @@ -300,15 +299,19 @@ class ImageFolder(DatasetFolder): """ def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - loader: Callable[[str], Any] = default_loader, - is_valid_file: Optional[Callable[[str], bool]] = None, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, ): - super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, - transform=transform, - target_transform=target_transform, - is_valid_file=is_valid_file) + super(ImageFolder, self).__init__( + root, + loader, + IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + is_valid_file=is_valid_file, + ) self.imgs = self.samples diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index 4912eb01600..bd1464646be 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -47,20 +47,33 @@ class HMDB51(VisionDataset): data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar" splits = { "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar", - "md5": "15e67781e70dcfbdce2d7dbb9b3344b5" + "md5": "15e67781e70dcfbdce2d7dbb9b3344b5", } TRAIN_TAG = 1 TEST_TAG = 2 - def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, - frame_rate=None, fold=1, train=True, transform=None, - _precomputed_metadata=None, num_workers=1, _video_width=0, - _video_height=0, _video_min_dimension=0, _audio_samples=0): + def __init__( + self, + root, + annotation_path, + frames_per_clip, + step_between_clips=1, + frame_rate=None, + fold=1, + train=True, + transform=None, + _precomputed_metadata=None, + num_workers=1, + _video_width=0, + _video_height=0, + _video_min_dimension=0, + _audio_samples=0, + ): super(HMDB51, self).__init__(root) if fold not in (1, 2, 3): raise ValueError("fold should be between 1 and 3, got {}".format(fold)) - extensions = ('avi',) + extensions = ("avi",) self.classes, class_to_idx = find_classes(self.root) self.samples = make_dataset( self.root, diff --git a/torchvision/datasets/imagenet.py 
b/torchvision/datasets/imagenet.py index 6dfc9bfebfd..1871f12ca23 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,17 +1,19 @@ -import warnings -from contextlib import contextmanager import os import shutil import tempfile -from typing import Any, Dict, List, Iterator, Optional, Tuple +import warnings +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Tuple + import torch + from .folder import ImageFolder from .utils import check_integrity, extract_archive, verify_str_arg ARCHIVE_META = { - 'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'), - 'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'), - 'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf') + "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), + "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), + "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), } META_FILE = "meta.bin" @@ -38,15 +40,16 @@ class ImageNet(ImageFolder): targets (list): The class_index value for each image in the dataset """ - def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None: + def __init__(self, root: str, split: str = "train", download: Optional[str] = None, **kwargs: Any) -> None: if download is True: - msg = ("The dataset is no longer publicly accessible. You need to " - "download the archives externally and place them in the root " - "directory.") + msg = ( + "The dataset is no longer publicly accessible. You need to " + "download the archives externally and place them in the root " + "directory." + ) raise RuntimeError(msg) elif download is False: - msg = ("The use of the download flag is deprecated, since the dataset " - "is no longer publicly accessible.") + msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible." warnings.warn(msg, RuntimeWarning) root = self.root = os.path.expanduser(root) @@ -61,18 +64,16 @@ def __init__(self, root: str, split: str = 'train', download: Optional[str] = No self.wnids = self.classes self.wnid_to_idx = self.class_to_idx self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] - self.class_to_idx = {cls: idx - for idx, clss in enumerate(self.classes) - for cls in clss} + self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss} def parse_archives(self) -> None: if not check_integrity(os.path.join(self.root, META_FILE)): parse_devkit_archive(self.root) if not os.path.isdir(self.split_folder): - if self.split == 'train': + if self.split == "train": parse_train_archive(self.root) - elif self.split == 'val': + elif self.split == "val": parse_val_archive(self.root) @property @@ -91,15 +92,19 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str if check_integrity(file): return torch.load(file) else: - msg = ("The meta file {} is not present in the root directory or is corrupted. " - "This file is automatically created by the ImageNet dataset.") + msg = ( + "The meta file {} is not present in the root directory or is corrupted. " + "This file is automatically created by the ImageNet dataset." + ) raise RuntimeError(msg.format(file, root)) def _verify_archive(root: str, file: str, md5: str) -> None: if not check_integrity(os.path.join(root, file), md5): - msg = ("The archive {} is not present in the root directory or is corrupted. 
" - "You need to download it externally and place it in {}.") + msg = ( + "The archive {} is not present in the root directory or is corrupted. " + "You need to download it externally and place it in {}." + ) raise RuntimeError(msg.format(file, root)) @@ -116,20 +121,18 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]: metafile = os.path.join(devkit_root, "data", "meta.mat") - meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] + meta = sio.loadmat(metafile, squeeze_me=True)["synsets"] nums_children = list(zip(*meta))[4] - meta = [meta[idx] for idx, num_children in enumerate(nums_children) - if num_children == 0] + meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0] idcs, wnids, classes = list(zip(*meta))[:3] - classes = [tuple(clss.split(', ')) for clss in classes] + classes = [tuple(clss.split(", ")) for clss in classes] idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} return idx_to_wnid, wnid_to_classes def parse_val_groundtruth_txt(devkit_root: str) -> List[int]: - file = os.path.join(devkit_root, "data", - "ILSVRC2012_validation_ground_truth.txt") - with open(file, 'r') as txtfh: + file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") + with open(file, "r") as txtfh: val_idcs = txtfh.readlines() return [int(val_idx) for val_idx in val_idcs] diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index 7b9a911f823..627ec1ef48f 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -1,29 +1,30 @@ -from PIL import Image import os import os.path -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from PIL import Image -from .vision import VisionDataset from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"] DATASET_URLS = { - '2017': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz', - '2018': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz', - '2019': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz', - '2021_train': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz', - '2021_train_mini': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz', - '2021_valid': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz', + "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz", + "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz", + "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz", + "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz", + "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz", + "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz", } DATASET_MD5 = { - '2017': '7c784ea5e424efaec655bd392f87301f', - '2018': 'b1c6952ce38f31868cc50ea72d066cc3', - '2019': 'c60a6e2962c9b8ccbd458d12c8582644', - '2021_train': '38a7bb733f7a09214d44293460ec0021', - '2021_train_mini': 
'db6ed8330e634445efc8fec83ae81442', - '2021_valid': 'f6f6e0e242e3d4c9569ba56400938afc', + "2017": "7c784ea5e424efaec655bd392f87301f", + "2018": "b1c6952ce38f31868cc50ea72d066cc3", + "2019": "c60a6e2962c9b8ccbd458d12c8582644", + "2021_train": "38a7bb733f7a09214d44293460ec0021", + "2021_train_mini": "db6ed8330e634445efc8fec83ae81442", + "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc", } @@ -63,27 +64,26 @@ class INaturalist(VisionDataset): """ def __init__( - self, - root: str, - version: str = "2021_train", - target_type: Union[List[str], str] = "full", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + version: str = "2021_train", + target_type: Union[List[str], str] = "full", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) - super(INaturalist, self).__init__(os.path.join(root, version), - transform=transform, - target_transform=target_transform) + super(INaturalist, self).__init__( + os.path.join(root, version), transform=transform, target_transform=target_transform + ) os.makedirs(root, exist_ok=True) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") self.all_categories: List[str] = [] @@ -96,12 +96,10 @@ def __init__( if not isinstance(target_type, list): target_type = [target_type] if self.version[:4] == "2021": - self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type] self._init_2021() else: - self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type] self._init_pre2021() # index of all files: (full category id, filename) @@ -118,16 +116,14 @@ def _init_2021(self) -> None: self.all_categories = sorted(os.listdir(self.root)) # map: category type -> name of category -> index - self.categories_index = { - k: {} for k in CATEGORIES_2021 - } + self.categories_index = {k: {} for k in CATEGORIES_2021} for dir_index, dir_name in enumerate(self.all_categories): - pieces = dir_name.split('_') + pieces = dir_name.split("_") if len(pieces) != 8: - raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces') - if pieces[0] != f'{dir_index:05d}': - raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}') + raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces") + if pieces[0] != f"{dir_index:05d}": + raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}") cat_map = {} for cat, name in zip(CATEGORIES_2021, pieces[1:7]): if name in self.categories_index[cat]: @@ -142,7 +138,7 @@ def _init_pre2021(self) -> None: """Initialize based on 2017-2019 layout""" # map: category type -> name of category -> index - self.categories_index = {'super': {}} + self.categories_index = {"super": {}} cat_index = 0 super_categories = sorted(os.listdir(self.root)) @@ -165,7 +161,7 @@ def _init_pre2021(self) -> None: self.all_categories.extend([""] * (subcat_i - old_len + 1)) if 
self.categories_map[subcat_i]: raise RuntimeError(f"Duplicate category {subcat}") - self.categories_map[subcat_i] = {'super': sindex} + self.categories_map[subcat_i] = {"super": sindex} self.all_categories[subcat_i] = os.path.join(scat, subcat) # validate the dictionary @@ -183,9 +179,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ cat_id, fname = self.index[index] - img = Image.open(os.path.join(self.root, - self.all_categories[cat_id], - fname)) + img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname)) target: Any = [] for t in self.target_type: @@ -239,10 +233,8 @@ def download(self) -> None: base_root = os.path.dirname(self.root) download_and_extract_archive( - DATASET_URLS[self.version], - base_root, - filename=f"{self.version}.tgz", - md5=DATASET_MD5[self.version]) + DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version] + ) orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz")) if not os.path.exists(orig_dir_name): diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 2543b6c514d..fc91a28fa46 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -1,16 +1,14 @@ -import time +import csv import os +import time import warnings - - -from os import path -import csv -from typing import Any, Callable, Dict, Optional, Tuple from functools import partial from multiprocessing import Pool +from os import path +from typing import Any, Callable, Dict, Optional, Tuple -from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity from .folder import find_classes, make_dataset +from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg from .video_utils import VideoClips from .vision import VisionDataset @@ -213,18 +211,13 @@ def _make_ds_structure(self): start=int(row["time_start"]), end=int(row["time_end"]), ) - label = ( - row["label"] - .replace(" ", "_") - .replace("'", "") - .replace("(", "") - .replace(")", "") - ) + label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "") os.makedirs(path.join(self.split_folder, label), exist_ok=True) downloaded_file = path.join(self.split_folder, f) if path.isfile(downloaded_file): os.replace( - downloaded_file, path.join(self.split_folder, label, f), + downloaded_file, + path.join(self.split_folder, label, f), ) @property @@ -302,11 +295,12 @@ def __init__( split: Any = None, download: Any = None, num_download_workers: Any = None, - **kwargs: Any + **kwargs: Any, ): warnings.warn( "Kinetics400 is deprecated and will be removed in a future release." - "It was replaced by Kinetics(..., num_classes=\"400\").") + 'It was replaced by Kinetics(..., num_classes="400").' + ) if any(value is not None for value in (num_classes, split, download, num_download_workers)): raise RuntimeError( "Usage of 'num_classes', 'split', 'download', or 'num_download_workers' is not supported in " diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 8db2e45b715..6120dcf65ae 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -71,9 +71,7 @@ def __init__( if download: self.download() if not self._check_exists(): - raise RuntimeError( - "Dataset not found. You may use download=True to download it." - ) + raise RuntimeError("Dataset not found. 
You may use download=True to download it.") image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name) if self.train: @@ -81,9 +79,7 @@ def __init__( for img_file in os.listdir(image_dir): self.images.append(os.path.join(image_dir, img_file)) if self.train: - self.targets.append( - os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt") - ) + self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")) def __getitem__(self, index: int) -> Tuple[Any, Any]: """Get item at a given index. @@ -115,16 +111,18 @@ def _parse_target(self, index: int) -> List: with open(self.targets[index]) as inp: content = csv.reader(inp, delimiter=" ") for line in content: - target.append({ - "type": line[0], - "truncated": float(line[1]), - "occluded": int(line[2]), - "alpha": float(line[3]), - "bbox": [float(x) for x in line[4:8]], - "dimensions": [float(x) for x in line[8:11]], - "location": [float(x) for x in line[11:14]], - "rotation_y": float(line[14]), - }) + target.append( + { + "type": line[0], + "truncated": float(line[1]), + "occluded": int(line[2]), + "alpha": float(line[3]), + "bbox": [float(x) for x in line[4:8]], + "dimensions": [float(x) for x in line[8:11]], + "location": [float(x) for x in line[11:14]], + "rotation_y": float(line[14]), + } + ) return target def __len__(self) -> int: @@ -139,10 +137,7 @@ def _check_exists(self) -> bool: folders = [self.image_dir_name] if self.train: folders.append(self.labels_dir_name) - return all( - os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) - for fname in folders - ) + return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders) def download(self) -> None: """Download the KITTI data if it doesn't exist already.""" diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 75b284b597f..87979300de6 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -1,29 +1,29 @@ -from .vision import VisionDataset -from PIL import Image +import io import os import os.path -import io +import pickle import string from collections.abc import Iterable -import pickle -from typing import Any, Callable, cast, List, Optional, Tuple, Union -from .utils import verify_str_arg, iterable_to_str +from typing import Any, Callable, List, Optional, Tuple, Union, cast + +from PIL import Image + +from .utils import iterable_to_str, verify_str_arg +from .vision import VisionDataset class LSUNClass(VisionDataset): def __init__( - self, root: str, transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None + self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None ) -> None: import lmdb - super(LSUNClass, self).__init__(root, transform=transform, - target_transform=target_transform) - self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, - readahead=False, meminit=False) + super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform) + + self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) with self.env.begin(write=False) as txn: - self.length = txn.stat()['entries'] - cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters) + self.length = txn.stat()["entries"] + cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) if os.path.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: @@ -40,7 +40,7 @@ def __getitem__(self, 
index: int) -> Tuple[Any, Any]: buf = io.BytesIO() buf.write(imgbuf) buf.seek(0) - img = Image.open(buf).convert('RGB') + img = Image.open(buf).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -69,22 +69,19 @@ class LSUN(VisionDataset): """ def __init__( - self, - root: str, - classes: Union[str, List[str]] = "train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + classes: Union[str, List[str]] = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: - super(LSUN, self).__init__(root, transform=transform, - target_transform=target_transform) + super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform) self.classes = self._verify_classes(classes) # for each class, create an LSUNClassDataset self.dbs = [] for c in self.classes: - self.dbs.append(LSUNClass( - root=os.path.join(root, f"{c}_lmdb"), - transform=transform)) + self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform)) self.indices = [] count = 0 @@ -95,35 +92,41 @@ def __init__( self.length = count def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]: - categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', - 'conference_room', 'dining_room', 'kitchen', - 'living_room', 'restaurant', 'tower'] - dset_opts = ['train', 'val', 'test'] + categories = [ + "bedroom", + "bridge", + "church_outdoor", + "classroom", + "conference_room", + "dining_room", + "kitchen", + "living_room", + "restaurant", + "tower", + ] + dset_opts = ["train", "val", "test"] try: classes = cast(str, classes) verify_str_arg(classes, "classes", dset_opts) - if classes == 'test': + if classes == "test": classes = [classes] else: - classes = [c + '_' + classes for c in categories] + classes = [c + "_" + classes for c in categories] except ValueError: if not isinstance(classes, Iterable): - msg = ("Expected type str or Iterable for argument classes, " - "but got type {}.") + msg = "Expected type str or Iterable for argument classes, " "but got type {}." raise ValueError(msg.format(type(classes))) classes = list(classes) - msg_fmtstr_type = ("Expected type str for elements in argument classes, " - "but got type {}.") + msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}." for c in classes: verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c))) - c_short = c.split('_') - category, dset_opt = '_'.join(c_short[:-1]), c_short[-1] + c_short = c.split("_") + category, dset_opt = "_".join(c_short[:-1]), c_short[-1] msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 
- msg = msg_fmtstr.format(category, "LSUN class", - iterable_to_str(categories)) + msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories)) verify_str_arg(category, valid_values=categories, custom_msg=msg) msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index edd2185c984..b23f7831bf3 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -1,16 +1,19 @@ -from .vision import VisionDataset -import warnings -from PIL import Image +import codecs import os import os.path -import numpy as np -import torch -import codecs +import shutil import string +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError -from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity -import shutil + +import numpy as np +from PIL import Image + +import torch + +from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg +from .vision import VisionDataset class MNIST(VisionDataset): @@ -31,21 +34,31 @@ class MNIST(VisionDataset): """ mirrors = [ - 'http://yann.lecun.com/exdb/mnist/', - 'https://ossci-datasets.s3.amazonaws.com/mnist/', + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", ] resources = [ ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), - ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") + ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), ] - training_file = 'training.pt' - test_file = 'test.pt' - classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', - '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + training_file = "training.pt" + test_file = "test.pt" + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] @property def train_labels(self): @@ -68,15 +81,14 @@ def test_data(self): return self.data def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(MNIST, self).__init__(root, transform=transform, - target_transform=target_transform) + super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform) torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}") self.train = train # training set or test set @@ -88,8 +100,7 @@ def __init__( self.download() if not self._check_exists(): - raise RuntimeError('Dataset not found.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found." 
+ " You can use download=True to download it") self.data, self.targets = self._load_data() @@ -129,7 +140,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img.numpy(), mode='L') + img = Image.fromarray(img.numpy(), mode="L") if self.transform is not None: img = self.transform(img) @@ -144,11 +155,11 @@ def __len__(self) -> int: @property def raw_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, 'raw') + return os.path.join(self.root, self.__class__.__name__, "raw") @property def processed_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, 'processed') + return os.path.join(self.root, self.__class__.__name__, "processed") @property def class_to_idx(self) -> Dict[str, int]: @@ -174,15 +185,9 @@ def download(self) -> None: url = "{}{}".format(mirror, filename) try: print("Downloading {}".format(url)) - download_and_extract_archive( - url, download_root=self.raw_folder, - filename=filename, - md5=md5 - ) + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) except URLError as error: - print( - "Failed to download (trying next):\n{}".format(error) - ) + print("Failed to download (trying next):\n{}".format(error)) continue finally: print() @@ -210,18 +215,16 @@ class FashionMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ - mirrors = [ - "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" - ] + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] resources = [ ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), - ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310") + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), ] - classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', - 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] class KMNIST(MNIST): @@ -240,17 +243,16 @@ class KMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ - mirrors = [ - "http://codh.rois.ac.jp/kmnist/dataset/kmnist/" - ] + + mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] resources = [ ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), - ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134") + ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), ] - classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo'] + classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] class EMNIST(MNIST): @@ -272,19 +274,20 @@ class EMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" - url = 'https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' + + url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" md5 = "58c8d27c78d21e728a6bc7b3cc06412e" - splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') + splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") # Merged Classes assumes Same structure for both uppercase and lowercase version - _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'} + _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} _all_classes = set(string.digits + string.ascii_letters) classes_split_dict = { - 'byclass': sorted(list(_all_classes)), - 'bymerge': sorted(list(_all_classes - _merged_classes)), - 'balanced': sorted(list(_all_classes - _merged_classes)), - 'letters': ['N/A'] + list(string.ascii_lowercase), - 'digits': list(string.digits), - 'mnist': list(string.digits), + "byclass": sorted(list(_all_classes)), + "bymerge": sorted(list(_all_classes - _merged_classes)), + "balanced": sorted(list(_all_classes - _merged_classes)), + "letters": ["N/A"] + list(string.ascii_lowercase), + "digits": list(string.digits), + "mnist": list(string.digits), } def __init__(self, root: str, split: str, **kwargs: Any) -> None: @@ -296,11 +299,11 @@ def __init__(self, root: str, split: str, **kwargs: Any) -> None: @staticmethod def _training_file(split) -> str: - return 'training_{}.pt'.format(split) + return "training_{}.pt".format(split) @staticmethod def _test_file(split) -> str: - return 'test_{}.pt'.format(split) + return "test_{}.pt".format(split) @property def _file_prefix(self) -> str: @@ -329,9 +332,9 @@ def download(self) -> None: os.makedirs(self.raw_folder, exist_ok=True) download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) - gzip_folder = os.path.join(self.raw_folder, 'gzip') + gzip_folder = os.path.join(self.raw_folder, "gzip") for gzip_file in os.listdir(gzip_folder): - if gzip_file.endswith('.gz'): + if gzip_file.endswith(".gz"): extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) shutil.rmtree(gzip_folder) @@ -366,39 +369,60 @@ class QMNIST(MNIST): training set ot the testing set. Default: True. 
""" - subsets = { - 'train': 'train', - 'test': 'test', - 'test10k': 'test', - 'test50k': 'test', - 'nist': 'nist' - } + subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] - 'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz', - 'ed72d4157d28c017586c42bc6afe6370'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz', - '0058f8dd561b90ffdd0f734c6a30e5e4')], - 'test': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz', - '1394631089c404de565df7b7aeaf9412'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz', - '5b5b05890a5e13444e108efe57b788aa')], - 'nist': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz', - '7f124b3b8ab81486c9d8c2749c17f834'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz', - '5ed0e788978e45d4a8bd4b7caec3d79d')] + "train": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", + "ed72d4157d28c017586c42bc6afe6370", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", + "0058f8dd561b90ffdd0f734c6a30e5e4", + ), + ], + "test": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", + "1394631089c404de565df7b7aeaf9412", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", + "5b5b05890a5e13444e108efe57b788aa", + ), + ], + "nist": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", + "7f124b3b8ab81486c9d8c2749c17f834", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", + "5ed0e788978e45d4a8bd4b7caec3d79d", + ), + ], } - classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', - '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] def __init__( - self, root: str, what: Optional[str] = None, compat: bool = True, - train: bool = True, **kwargs: Any + self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any ) -> None: if what is None: - what = 'train' if train else 'test' + what = "train" if train else "test" self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) self.compat = compat - self.data_file = what + '.pt' + self.data_file = what + ".pt" self.training_file = self.data_file self.test_file = self.data_file super(QMNIST, self).__init__(root, train, **kwargs) @@ -418,16 +442,16 @@ def _check_exists(self) -> bool: def _load_data(self): data = read_sn3_pascalvincent_tensor(self.images_file) - assert (data.dtype == torch.uint8) - assert (data.ndimension() == 3) + assert data.dtype == torch.uint8 + assert data.ndimension() == 3 targets = read_sn3_pascalvincent_tensor(self.labels_file).long() - assert (targets.ndimension() == 2) + assert targets.ndimension() == 2 - if self.what == 'test10k': + if self.what == "test10k": data = data[0:10000, :, :].clone() targets = targets[0:10000, :].clone() - elif self.what == 
'test50k': + elif self.what == "test50k": data = data[10000:, :, :].clone() targets = targets[10000:, :].clone() @@ -435,7 +459,7 @@ def _load_data(self): def download(self) -> None: """Download the QMNIST data if it doesn't exist already. - Note that we only download what has been asked for (argument 'what'). + Note that we only download what has been asked for (argument 'what'). """ if self._check_exists(): return @@ -444,7 +468,7 @@ def download(self) -> None: split = self.resources[self.subsets[self.what]] for url, md5 in split: - filename = url.rpartition('/')[2] + filename = url.rpartition("/")[2] file_path = os.path.join(self.raw_folder, filename) if not os.path.isfile(file_path): download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5) @@ -452,7 +476,7 @@ def download(self) -> None: def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag img, target = self.data[index], self.targets[index] - img = Image.fromarray(img.numpy(), mode='L') + img = Image.fromarray(img.numpy(), mode="L") if self.transform is not None: img = self.transform(img) if self.compat: @@ -466,22 +490,22 @@ def extra_repr(self) -> str: def get_int(b: bytes) -> int: - return int(codecs.encode(b, 'hex'), 16) + return int(codecs.encode(b, "hex"), 16) SN3_PASCALVINCENT_TYPEMAP = { 8: (torch.uint8, np.uint8, np.uint8), 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype('>i2'), 'i2'), - 12: (torch.int32, np.dtype('>i4'), 'i4'), - 13: (torch.float32, np.dtype('>f4'), 'f4'), - 14: (torch.float64, np.dtype('>f8'), 'f8') + 11: (torch.int16, np.dtype(">i2"), "i2"), + 12: (torch.int32, np.dtype(">i4"), "i4"), + 13: (torch.float32, np.dtype(">f4"), "f4"), + 14: (torch.float64, np.dtype(">f8"), "f8"), } def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). - Argument may be a filename, compressed filename, or file object. + Argument may be a filename, compressed filename, or file object. 
""" # read with open(path, "rb") as f: @@ -493,7 +517,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso assert 1 <= nd <= 3 assert 8 <= ty <= 14 m = SN3_PASCALVINCENT_TYPEMAP[ty] - s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] + s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) assert parsed.shape[0] == np.prod(s) or not strict return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) @@ -501,13 +525,13 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso def read_label_file(path: str) -> torch.Tensor: x = read_sn3_pascalvincent_tensor(path, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 1) + assert x.dtype == torch.uint8 + assert x.ndimension() == 1 return x.long() def read_image_file(path: str) -> torch.Tensor: x = read_sn3_pascalvincent_tensor(path, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 3) + assert x.dtype == torch.uint8 + assert x.ndimension() == 3 return x diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index b78bf86d16f..a96b5612fd7 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,8 +1,10 @@ -from PIL import Image from os.path import join from typing import Any, Callable, List, Optional, Tuple + +from PIL import Image + +from .utils import check_integrity, download_and_extract_archive, list_dir, list_files from .vision import VisionDataset -from .utils import download_and_extract_archive, check_integrity, list_dir, list_files class Omniglot(VisionDataset): @@ -21,38 +23,40 @@ class Omniglot(VisionDataset): puts it in root directory. If the zip files are already downloaded, they are not downloaded again. """ - folder = 'omniglot-py' - download_url_prefix = 'https://raw.githubusercontent.com/brendenlake/omniglot/master/python' + + folder = "omniglot-py" + download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python" zips_md5 = { - 'images_background': '68d2efa1b9178cc56df9314c21c6e718', - 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' + "images_background": "68d2efa1b9178cc56df9314c21c6e718", + "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811", } def __init__( - self, - root: str, - background: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + background: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Omniglot, self).__init__(join(root, self.folder), transform=transform, - target_transform=target_transform) + super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform) self.background = background if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") self.target_folder = join(self.root, self._get_target_folder()) self._alphabets = list_dir(self.target_folder) - self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] - for a in self._alphabets], []) - self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] - for idx, character in enumerate(self._characters)] + self._characters: List[str] = sum( + [[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], [] + ) + self._character_images = [ + [(image, idx) for image in list_files(join(self.target_folder, character), ".png")] + for idx, character in enumerate(self._characters) + ] self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) def __len__(self) -> int: @@ -68,7 +72,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ image_name, character_class = self._flat_character_images[index] image_path = join(self.target_folder, self._characters[character_class], image_name) - image = Image.open(image_path, mode='r').convert('L') + image = Image.open(image_path, mode="r").convert("L") if self.transform: image = self.transform(image) @@ -80,19 +84,19 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: def _check_integrity(self) -> bool: zip_filename = self._get_target_folder() - if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): + if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]): return False return True def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return filename = self._get_target_folder() - zip_filename = filename + '.zip' - url = self.download_url_prefix + '/' + zip_filename + zip_filename = filename + ".zip" + url = self.download_url_prefix + "/" + zip_filename download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename]) def _get_target_folder(self) -> str: - return 'images_background' if self.background else 'images_evaluation' + return "images_background" if self.background else "images_evaluation" diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index abb89701e1e..f74a5451187 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -1,12 +1,13 @@ import os +from typing import Any, Callable, List, Optional, Tuple, Union + import numpy as np from PIL import Image -from typing import Any, Callable, List, Optional, Tuple, Union import torch -from .vision import VisionDataset from .utils import download_url +from .vision import VisionDataset class PhotoTour(VisionDataset): @@ -33,56 +34,67 @@ class PhotoTour(VisionDataset): downloaded again. 
""" + urls = { - 'notredame_harris': [ - 'http://matthewalunbrown.com/patchdata/notredame_harris.zip', - 'notredame_harris.zip', - '69f8c90f78e171349abdf0307afefe4d' - ], - 'yosemite_harris': [ - 'http://matthewalunbrown.com/patchdata/yosemite_harris.zip', - 'yosemite_harris.zip', - 'a73253d1c6fbd3ba2613c45065c00d46' - ], - 'liberty_harris': [ - 'http://matthewalunbrown.com/patchdata/liberty_harris.zip', - 'liberty_harris.zip', - 'c731fcfb3abb4091110d0ae8c7ba182c' + "notredame_harris": [ + "http://matthewalunbrown.com/patchdata/notredame_harris.zip", + "notredame_harris.zip", + "69f8c90f78e171349abdf0307afefe4d", ], - 'notredame': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip', - 'notredame.zip', - '509eda8535847b8c0a90bbb210c83484' + "yosemite_harris": [ + "http://matthewalunbrown.com/patchdata/yosemite_harris.zip", + "yosemite_harris.zip", + "a73253d1c6fbd3ba2613c45065c00d46", ], - 'yosemite': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip', - 'yosemite.zip', - '533b2e8eb7ede31be40abc317b2fd4f0' + "liberty_harris": [ + "http://matthewalunbrown.com/patchdata/liberty_harris.zip", + "liberty_harris.zip", + "c731fcfb3abb4091110d0ae8c7ba182c", ], - 'liberty': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip', - 'liberty.zip', - 'fdd9152f138ea5ef2091746689176414' + "notredame": [ + "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip", + "notredame.zip", + "509eda8535847b8c0a90bbb210c83484", ], + "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"], + "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"], } - means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437, - 'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437} - stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019, - 'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019} - lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, - 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295} - image_ext = 'bmp' - info_file = 'info.txt' - matches_files = 'm50_100000_100000_0.txt' + means = { + "notredame": 0.4854, + "yosemite": 0.4844, + "liberty": 0.4437, + "notredame_harris": 0.4854, + "yosemite_harris": 0.4844, + "liberty_harris": 0.4437, + } + stds = { + "notredame": 0.1864, + "yosemite": 0.1818, + "liberty": 0.2019, + "notredame_harris": 0.1864, + "yosemite_harris": 0.1818, + "liberty_harris": 0.2019, + } + lens = { + "notredame": 468159, + "yosemite": 633587, + "liberty": 450092, + "liberty_harris": 379587, + "yosemite_harris": 450912, + "notredame_harris": 325295, + } + image_ext = "bmp" + info_file = "info.txt" + matches_files = "m50_100000_100000_0.txt" def __init__( - self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False + self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False ) -> None: super(PhotoTour, self).__init__(root, transform=transform) self.name = name self.data_dir = os.path.join(self.root, name) - self.data_down = os.path.join(self.root, '{}.zip'.format(name)) - self.data_file = os.path.join(self.root, '{}.pt'.format(name)) + self.data_down = os.path.join(self.root, "{}.zip".format(name)) + self.data_file = os.path.join(self.root, "{}.pt".format(name)) self.train = train self.mean = self.means[name] @@ -128,7 +140,7 @@ def _check_downloaded(self) -> bool: def download(self) -> None: if 
self._check_datafile_exists(): - print('# Found cached data {}'.format(self.data_file)) + print("# Found cached data {}".format(self.data_file)) return if not self._check_downloaded(): @@ -140,25 +152,26 @@ def download(self) -> None: download_url(url, self.root, filename, md5) - print('# Extracting data {}\n'.format(self.data_down)) + print("# Extracting data {}\n".format(self.data_down)) import zipfile - with zipfile.ZipFile(fpath, 'r') as z: + + with zipfile.ZipFile(fpath, "r") as z: z.extractall(self.data_dir) os.unlink(fpath) def cache(self) -> None: # process and save as torch files - print('# Caching data {}'.format(self.data_file)) + print("# Caching data {}".format(self.data_file)) dataset = ( read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_info_file(self.data_dir, self.info_file), - read_matches_files(self.data_dir, self.matches_files) + read_matches_files(self.data_dir, self.matches_files), ) - with open(self.data_file, 'wb') as f: + with open(self.data_file, "wb") as f: torch.save(dataset, f) def extra_repr(self) -> str: @@ -166,17 +179,14 @@ def extra_repr(self) -> str: def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: - """Return a Tensor containing the patches - """ + """Return a Tensor containing the patches""" def PIL2array(_img: Image.Image) -> np.ndarray: - """Convert PIL image type to numpy 2D array - """ + """Convert PIL image type to numpy 2D array""" return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64) def find_files(_data_dir: str, _image_ext: str) -> List[str]: - """Return a list with the file names of the images containing the patches - """ + """Return a list with the file names of the images containing the patches""" files = [] # find those files with the specified extension for file_dir in os.listdir(_data_dir): @@ -198,22 +208,21 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]: def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: """Return a Tensor containing the list of labels - Read the file and keep only the ID of the 3D point. + Read the file and keep only the ID of the 3D point. """ - with open(os.path.join(data_dir, info_file), 'r') as f: + with open(os.path.join(data_dir, info_file), "r") as f: labels = [int(line.split()[0]) for line in f] return torch.LongTensor(labels) def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor: """Return a Tensor containing the ground truth matches - Read the file and keep only 3D point ID. - Matches are represented with a 1, non matches with a 0. + Read the file and keep only 3D point ID. + Matches are represented with a 1, non matches with a 0. 
""" matches = [] - with open(os.path.join(data_dir, matches_file), 'r') as f: + with open(os.path.join(data_dir, matches_file), "r") as f: for line in f: line_split = line.split() - matches.append([int(line_split[0]), int(line_split[3]), - int(line_split[1] == line_split[4])]) + matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])]) return torch.LongTensor(matches) diff --git a/torchvision/datasets/places365.py b/torchvision/datasets/places365.py index 648e0d604ba..aab1e7b5086 100644 --- a/torchvision/datasets/places365.py +++ b/torchvision/datasets/places365.py @@ -4,7 +4,7 @@ from urllib.parse import urljoin from .folder import default_loader -from .utils import verify_str_arg, check_integrity, download_and_extract_archive +from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/samplers/__init__.py b/torchvision/datasets/samplers/__init__.py index 870322d39b4..58b2d2abd93 100644 --- a/torchvision/datasets/samplers/__init__.py +++ b/torchvision/datasets/samplers/__init__.py @@ -1,3 +1,3 @@ -from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler +from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler -__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler') +__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler") diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py index 0f90e3ad1b0..a4a589dfc18 100644 --- a/torchvision/datasets/samplers/clip_sampler.py +++ b/torchvision/datasets/samplers/clip_sampler.py @@ -1,9 +1,10 @@ import math +from typing import Iterator, List, Optional, Sized, Union, cast + import torch -from torch.utils.data import Sampler import torch.distributed as dist +from torch.utils.data import Sampler from torchvision.datasets.video_utils import VideoClips -from typing import Optional, List, Iterator, Sized, Union, cast class DistributedSampler(Sampler): @@ -36,12 +37,12 @@ class DistributedSampler(Sampler): """ def __init__( - self, - dataset: Sized, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = False, - group_size: int = 1, + self, + dataset: Sized, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + group_size: int = 1, ) -> None: if num_replicas is None: if not dist.is_available(): @@ -51,9 +52,11 @@ def __init__( if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() - assert len(dataset) % group_size == 0, ( - "dataset length must be a multiplier of group size" - "dataset length: %d, group size: %d" % (len(dataset), group_size) + assert ( + len(dataset) % group_size == 0 + ), "dataset length must be a multiplier of group size" "dataset length: %d, group size: %d" % ( + len(dataset), + group_size, ) self.dataset = dataset self.group_size = group_size @@ -61,9 +64,7 @@ def __init__( self.rank = rank self.epoch = 0 dataset_group_length = len(dataset) // group_size - self.num_group_samples = int( - math.ceil(dataset_group_length * 1.0 / self.num_replicas) - ) + self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas)) self.num_samples = self.num_group_samples * group_size self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle @@ -79,16 +80,14 @@ def __iter__(self) -> Iterator[int]: indices = 
list(range(len(self.dataset))) # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size total_group_size = self.total_size // self.group_size - indices = torch.reshape( - torch.LongTensor(indices), (total_group_size, self.group_size) - ) + indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size)) # subsample - indices = indices[self.rank:total_group_size:self.num_replicas, :] + indices = indices[self.rank : total_group_size : self.num_replicas, :] indices = torch.reshape(indices, (-1,)).tolist() assert len(indices) == self.num_samples @@ -115,10 +114,10 @@ class UniformClipSampler(Sampler): video_clips (VideoClips): video clips to sample from num_clips_per_video (int): number of clips to be sampled per video """ + def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): - raise TypeError("Expected video_clips to be an instance of VideoClips, " - "got {}".format(type(video_clips))) + raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.num_clips_per_video = num_clips_per_video @@ -132,19 +131,13 @@ def __iter__(self) -> Iterator[int]: # corner case where video decoding fails continue - sampled = ( - torch.linspace(s, s + length - 1, steps=self.num_clips_per_video) - .floor() - .to(torch.int64) - ) + sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64) s += length idxs.append(sampled) return iter(cast(List[int], torch.cat(idxs).tolist())) def __len__(self) -> int: - return sum( - self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0 - ) + return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0) class RandomClipSampler(Sampler): @@ -155,10 +148,10 @@ class RandomClipSampler(Sampler): video_clips (VideoClips): video clips to sample from max_clips_per_video (int): maximum number of clips to be sampled per video """ + def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): - raise TypeError("Expected video_clips to be an instance of VideoClips, " - "got {}".format(type(video_clips))) + raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.max_clips_per_video = max_clips_per_video diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index e47c9493858..2ff6a1a55d9 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -1,12 +1,12 @@ import os import shutil -from .vision import VisionDataset from typing import Any, Callable, Optional, Tuple import numpy as np - from PIL import Image -from .utils import download_url, verify_str_arg, download_and_extract_archive + +from .utils import download_and_extract_archive, download_url, verify_str_arg +from .vision import VisionDataset class SBDataset(VisionDataset): @@ -50,30 +50,29 @@ class SBDataset(VisionDataset): voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722" def __init__( - self, - root: str, - image_set: str = "train", - mode: str = "boundaries", - download: bool = False, - transforms: Optional[Callable] = None, + self, + root: str, + image_set: str = "train", + mode: str = "boundaries", + download: bool = False, + transforms: 
Optional[Callable] = None, ) -> None: try: from scipy.io import loadmat + self._loadmat = loadmat except ImportError: - raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " - "pip install scipy") + raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy") super(SBDataset, self).__init__(root, transforms) - self.image_set = verify_str_arg(image_set, "image_set", - ("train", "val", "train_noval")) + self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval")) self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries")) self.num_classes = 20 sbd_root = self.root - image_dir = os.path.join(sbd_root, 'img') - mask_dir = os.path.join(sbd_root, 'cls') + image_dir = os.path.join(sbd_root, "img") + mask_dir = os.path.join(sbd_root, "cls") if download: download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) @@ -81,36 +80,35 @@ def __init__( for f in ["cls", "img", "inst", "train.txt", "val.txt"]: old_path = os.path.join(extracted_ds_root, f) shutil.move(old_path, sbd_root) - download_url(self.voc_train_url, sbd_root, self.voc_split_filename, - self.voc_split_md5) + download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5) if not os.path.isdir(sbd_root): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") - split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt') + split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt") with open(os.path.join(split_f), "r") as fh: file_names = [x.strip() for x in fh.readlines()] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names] - assert (len(self.images) == len(self.masks)) + assert len(self.images) == len(self.masks) - self._get_target = self._get_segmentation_target \ - if self.mode == "segmentation" else self._get_boundaries_target + self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target def _get_segmentation_target(self, filepath: str) -> Image.Image: mat = self._loadmat(filepath) - return Image.fromarray(mat['GTcls'][0]['Segmentation'][0]) + return Image.fromarray(mat["GTcls"][0]["Segmentation"][0]) def _get_boundaries_target(self, filepath: str) -> np.ndarray: mat = self._loadmat(filepath) - return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0) - for i in range(self.num_classes)], axis=0) + return np.concatenate( + [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)], + axis=0, + ) def __getitem__(self, index: int) -> Tuple[Any, Any]: - img = Image.open(self.images[index]).convert('RGB') + img = Image.open(self.images[index]).convert("RGB") target = self._get_target(self.masks[index]) if self.transforms is not None: @@ -123,4 +121,4 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Image set: {image_set}", "Mode: {mode}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index 6c8ad15686b..ab2ac724df1 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -1,8 +1,9 @@ -from PIL import Image -from .utils import 
download_url, check_integrity +import os from typing import Any, Callable, Optional, Tuple -import os +from PIL import Image + +from .utils import check_integrity, download_url from .vision import VisionDataset @@ -20,38 +21,37 @@ class SBU(VisionDataset): puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ + url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz" - md5_checksum = '9aec147b3488753cf758b4d493422285' + md5_checksum = "9aec147b3488753cf758b4d493422285" def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, ) -> None: - super(SBU, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SBU, self).__init__(root, transform=transform, target_transform=target_transform) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") # Read the caption for each photo self.photos = [] self.captions = [] - file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt') - file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt') + file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt") + file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt") for line1, line2 in zip(open(file1), open(file2)): url = line1.rstrip() photo = os.path.basename(url) - filename = os.path.join(self.root, 'dataset', photo) + filename = os.path.join(self.root, "dataset", photo) if os.path.exists(filename): caption = line2.rstrip() self.photos.append(photo) @@ -65,8 +65,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: Returns: tuple: (image, target) where target is a caption for the photo. """ - filename = os.path.join(self.root, 'dataset', self.photos[index]) - img = Image.open(filename).convert('RGB') + filename = os.path.join(self.root, "dataset", self.photos[index]) + img = Image.open(filename).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -93,21 +93,21 @@ def download(self) -> None: import tarfile if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_url(self.url, self.root, self.filename, self.md5_checksum) # Extract file - with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar: + with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: tar.extractall(path=self.root) # Download individual photos - with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh: + with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh: for line in fh: url = line.rstrip() try: - download_url(url, os.path.join(self.root, 'dataset')) + download_url(url, os.path.join(self.root, "dataset")) except OSError: # The images point to public images on Flickr. # Note: Images might be removed by users at anytime. 
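Not part of the patch itself: a minimal usage sketch of the SBU dataset reformatted above, assuming torchvision is installed and the (large) SBU archive can be downloaded; the root path below is hypothetical.

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Instantiate the captioned-photo dataset; download=True fetches the archive
# and the individual Flickr photos into <root>/dataset/.
sbu = datasets.SBU(root="./data", transform=transforms.ToTensor(), download=True)

# Each item is an (image, caption) pair; with ToTensor the image is a CxHxW tensor.
image, caption = sbu[0]
print(image.shape, caption)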
diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 20ce4e5f5d5..724e881f787 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -1,10 +1,12 @@ -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple + +import numpy as np +from PIL import Image + +from .utils import check_integrity, download_url from .vision import VisionDataset -from .utils import download_url, check_integrity class SEMEION(VisionDataset): @@ -24,30 +26,28 @@ class SEMEION(VisionDataset): """ url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" filename = "semeion.data" - md5_checksum = 'cb545d371d2ce14ec121470795a77432' + md5_checksum = "cb545d371d2ce14ec121470795a77432" def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, ) -> None: - super(SEMEION, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") fp = os.path.join(self.root, self.filename) data = np.loadtxt(fp) # convert value to 8 bit unsigned integer # color (white #255) the pixels - self.data = (data[:, :256] * 255).astype('uint8') + self.data = (data[:, :256] * 255).astype("uint8") self.data = np.reshape(self.data, (-1, 16, 16)) self.labels = np.nonzero(data[:, 256:])[1] @@ -63,7 +63,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img, mode='L') + img = Image.fromarray(img, mode="L") if self.transform is not None: img = self.transform(img) @@ -85,7 +85,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return root = self.root diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 50e9af882bc..20ebbc3b0ee 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -1,11 +1,12 @@ -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset +import numpy as np +from PIL import Image + from .utils import check_integrity, download_and_extract_archive, verify_str_arg +from .vision import VisionDataset class STL10(VisionDataset): @@ -27,70 +28,60 @@ class STL10(VisionDataset): puts it in root directory. If dataset is already downloaded, it is not downloaded again. 
""" - base_folder = 'stl10_binary' + + base_folder = "stl10_binary" url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" filename = "stl10_binary.tar.gz" - tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb' - class_names_file = 'class_names.txt' - folds_list_file = 'fold_indices.txt' + tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb" + class_names_file = "class_names.txt" + folds_list_file = "fold_indices.txt" train_list = [ - ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], - ['train_y.bin', '5a34089d4802c674881badbb80307741'], - ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4'] + ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"], + ["train_y.bin", "5a34089d4802c674881badbb80307741"], + ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"], ] - test_list = [ - ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], - ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] - ] - splits = ('train', 'train+unlabeled', 'unlabeled', 'test') + test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]] + splits = ("train", "train+unlabeled", "unlabeled", "test") def __init__( - self, - root: str, - split: str = "train", - folds: Optional[int] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + folds: Optional[int] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(STL10, self).__init__(root, transform=transform, - target_transform=target_transform) + super(STL10, self).__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", self.splits) self.folds = self._verify_folds(folds) if download: self.download() elif not self._check_integrity(): - raise RuntimeError( - 'Dataset not found or corrupted. ' - 'You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted. 
" "You can use download=True to download it") # now load the picked numpy arrays self.labels: Optional[np.ndarray] - if self.split == 'train': - self.data, self.labels = self.__loadfile( - self.train_list[0][0], self.train_list[1][0]) + if self.split == "train": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.__load_folds(folds) - elif self.split == 'train+unlabeled': - self.data, self.labels = self.__loadfile( - self.train_list[0][0], self.train_list[1][0]) + elif self.split == "train+unlabeled": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.__load_folds(folds) unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) self.data = np.concatenate((self.data, unlabeled_data)) - self.labels = np.concatenate( - (self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) + self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) - elif self.split == 'unlabeled': + elif self.split == "unlabeled": self.data, _ = self.__loadfile(self.train_list[2][0]) self.labels = np.asarray([-1] * self.data.shape[0]) else: # self.split == 'test': - self.data, self.labels = self.__loadfile( - self.test_list[0][0], self.test_list[1][0]) + self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0]) - class_file = os.path.join( - self.root, self.base_folder, self.class_names_file) + class_file = os.path.join(self.root, self.base_folder, self.class_names_file) if os.path.isfile(class_file): with open(class_file) as f: self.classes = f.read().splitlines() @@ -101,8 +92,7 @@ def _verify_folds(self, folds: Optional[int]) -> Optional[int]: elif isinstance(folds, int): if folds in range(10): return folds - msg = ("Value for argument folds should be in the range [0, 10), " - "but got {}.") + msg = "Value for argument folds should be in the range [0, 10), " "but got {}." raise ValueError(msg.format(folds)) else: msg = "Expected type None or int for argument folds, but got type {}." 
@@ -140,13 +130,12 @@ def __len__(self) -> int: def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: labels = None if labels_file: - path_to_labels = os.path.join( - self.root, self.base_folder, labels_file) - with open(path_to_labels, 'rb') as f: + path_to_labels = os.path.join(self.root, self.base_folder, labels_file) + with open(path_to_labels, "rb") as f: labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based path_to_data = os.path.join(self.root, self.base_folder, data_file) - with open(path_to_data, 'rb') as f: + with open(path_to_data, "rb") as f: # read whole file in uint8 chunks everything = np.fromfile(f, dtype=np.uint8) images = np.reshape(everything, (-1, 3, 96, 96)) @@ -156,7 +145,7 @@ def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple def _check_integrity(self) -> bool: root = self.root - for fentry in (self.train_list + self.test_list): + for fentry in self.train_list + self.test_list: filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): @@ -165,7 +154,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) self._check_integrity() @@ -177,11 +166,10 @@ def __load_folds(self, folds: Optional[int]) -> None: # loads one of the folds if specified if folds is None: return - path_to_folds = os.path.join( - self.root, self.base_folder, self.folds_list_file) - with open(path_to_folds, 'r') as f: + path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file) + with open(path_to_folds, "r") as f: str_idx = f.read().splitlines()[folds] - list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ') + list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ") self.data = self.data[list_idx, :, :, :] if self.labels is not None: self.labels = self.labels[list_idx] diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index f1adee687eb..5e43367bfa6 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -1,10 +1,12 @@ -from .vision import VisionDataset -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple -from .utils import download_url, check_integrity, verify_str_arg + +import numpy as np +from PIL import Image + +from .utils import check_integrity, download_url, verify_str_arg +from .vision import VisionDataset class SVHN(VisionDataset): @@ -33,23 +35,32 @@ class SVHN(VisionDataset): """ split_list = { - 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", - "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], - 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", - "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], - 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", - "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} + "train": [ + "http://ufldl.stanford.edu/housenumbers/train_32x32.mat", + "train_32x32.mat", + "e26dedcc434d2e4c54c9b2d4a06d8373", + ], + "test": [ + "http://ufldl.stanford.edu/housenumbers/test_32x32.mat", + "test_32x32.mat", + "eb5a983be6a315427106f1b164d9cef3", + ], + "extra": [ + "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", + "extra_32x32.mat", + 
"a93ce644f1a588dc4d68dda5feec44a7", + ], + } def __init__( - self, - root: str, - split: str = "train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(SVHN, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) self.url = self.split_list[split][0] self.filename = self.split_list[split][1] @@ -59,8 +70,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") # import here rather than at top of file because this is # an optional dependency for torchvision @@ -69,12 +79,12 @@ def __init__( # reading(loading) mat file as array loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) - self.data = loaded_mat['X'] + self.data = loaded_mat["X"] # loading from the .mat file gives an np array of type np.uint8 # converting to np.int64, so that we have a LongTensor after # the conversion from the numpy array # the squeeze is needed to obtain a 1D tensor - self.labels = loaded_mat['y'].astype(np.int64).squeeze() + self.labels = loaded_mat["y"].astype(np.int64).squeeze() # the svhn dataset assigns the class label "10" to the digit 0 # this makes it inconsistent with several loss functions diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index 71f62257bcb..f3fdfe96533 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -42,15 +42,28 @@ class UCF101(VisionDataset): - label (int): class of the video clip """ - def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, - frame_rate=None, fold=1, train=True, transform=None, - _precomputed_metadata=None, num_workers=1, _video_width=0, - _video_height=0, _video_min_dimension=0, _audio_samples=0): + def __init__( + self, + root, + annotation_path, + frames_per_clip, + step_between_clips=1, + frame_rate=None, + fold=1, + train=True, + transform=None, + _precomputed_metadata=None, + num_workers=1, + _video_width=0, + _video_height=0, + _video_min_dimension=0, + _audio_samples=0, + ): super(UCF101, self).__init__(root) if not 1 <= fold <= 3: raise ValueError("fold should be between 1 and 3, got {}".format(fold)) - extensions = ('avi',) + extensions = ("avi",) self.fold = fold self.train = train diff --git a/torchvision/datasets/usps.py b/torchvision/datasets/usps.py index c315b8d3111..f41c6695f2b 100644 --- a/torchvision/datasets/usps.py +++ b/torchvision/datasets/usps.py @@ -1,7 +1,8 @@ -from PIL import Image import os +from typing import Any, Callable, Optional, Tuple, cast + import numpy as np -from typing import Any, Callable, cast, Optional, Tuple +from PIL import Image from .utils import download_url from .vision import VisionDataset @@ -26,28 +27,30 @@ class USPS(VisionDataset): downloaded again. 
""" + split_list = { - 'train': [ + "train": [ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2", - "usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197' + "usps.bz2", + "ec16c51db3855ca6c91edd34d0e9b197", ], - 'test': [ + "test": [ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2", - "usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118' + "usps.t.bz2", + "8ea070ee2aca1ac39742fdd1ef5ed118", ], } def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(USPS, self).__init__(root, transform=transform, - target_transform=target_transform) - split = 'train' if train else 'test' + super(USPS, self).__init__(root, transform=transform, target_transform=target_transform) + split = "train" if train else "test" url, filename, checksum = self.split_list[split] full_path = os.path.join(self.root, filename) @@ -55,9 +58,10 @@ def __init__( download_url(url, self.root, filename, md5=checksum) import bz2 + with bz2.open(full_path) as fp: raw_data = [line.decode().split() for line in fp.readlines()] - tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] + tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8) targets = [int(d[0]) - 1 for d in raw_data] @@ -77,7 +81,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img, mode='L') + img = Image.fromarray(img, mode="L") if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 9ae726edd8f..eb1d108c0a8 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,28 +1,24 @@ import bz2 +import gzip +import hashlib +import itertools +import lzma import os import os.path -import hashlib -import gzip +import pathlib import re import tarfile -from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator -from urllib.parse import urlparse -import zipfile -import lzma import urllib -import urllib.request import urllib.error -import pathlib -import itertools +import urllib.request +import zipfile +from typing import IO, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar +from urllib.parse import urlparse import torch from torch.utils.model_zoo import tqdm -from .._internally_replaced_utils import ( - _download_file_from_remote_location, - _is_remote_location_available, -) - +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available USER_AGENT = "pytorch/vision" @@ -52,8 +48,8 @@ def bar_update(count, block_size, total_size): def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: md5 = hashlib.md5() - with open(fpath, 'rb') as f: - for chunk in iter(lambda: f.read(chunk_size), b''): + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): md5.update(chunk) return md5.hexdigest() @@ -120,7 +116,7 @@ def download_url( # check if file is already present locally if check_integrity(fpath, md5): - print('Using 
downloaded and verified file: ' + fpath) + print("Using downloaded and verified file: " + fpath) return if _is_remote_location_available(): @@ -136,13 +132,12 @@ def download_url( # download the file try: - print('Downloading ' + url + ' to ' + fpath) + print("Downloading " + url + " to " + fpath) _urlretrieve(url, fpath) except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] - if url[:5] == 'https': - url = url.replace('https:', 'http:') - print('Failed download. Trying https -> http instead.' - ' Downloading ' + url + ' to ' + fpath) + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath) _urlretrieve(url, fpath) else: raise e @@ -202,6 +197,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ """ # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url import requests + url = "https://docs.google.com/uc?export=download" root = os.path.expanduser(root) @@ -212,15 +208,15 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ os.makedirs(root, exist_ok=True) if os.path.isfile(fpath) and check_integrity(fpath, md5): - print('Using downloaded and verified file: ' + fpath) + print("Using downloaded and verified file: " + fpath) else: session = requests.Session() - response = session.get(url, params={'id': file_id}, stream=True) + response = session.get(url, params={"id": file_id}, stream=True) token = _get_confirm_token(response) if token: - params = {'id': file_id, 'confirm': token} + params = {"id": file_id, "confirm": token} response = session.get(url, params=params, stream=True) # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent @@ -240,20 +236,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ ) raise RuntimeError(msg) - _save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath) + _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath) response.close() def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] for key, value in response.cookies.items(): - if key.startswith('download_warning'): + if key.startswith("download_warning"): return value return None def _save_response_content( - response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined] + response_gen: Iterator[bytes], + destination: str, # type: ignore[name-defined] ) -> None: with open(destination, "wb") as f: pbar = tqdm(total=None) @@ -439,7 +436,10 @@ def iterable_to_str(iterable: Iterable) -> str: def verify_str_arg( - value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None, + value: T, + arg: Optional[str] = None, + valid_values: Iterable[T] = None, + custom_msg: Optional[str] = None, ) -> T: if not isinstance(value, torch._six.string_classes): if arg is None: @@ -456,10 +456,8 @@ def verify_str_arg( if custom_msg is not None: msg = custom_msg else: - msg = ("Unknown value '{value}' for argument {arg}. " - "Valid values are {{{valid_values}}}.") - msg = msg.format(value=value, arg=arg, - valid_values=iterable_to_str(valid_values)) + msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}." 
+ msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) raise ValueError(msg) return value diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 987270c4cd4..26a1d804c41 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -5,12 +5,7 @@ from typing import List import torch -from torchvision.io import ( - _probe_video_from_file, - _read_video_from_file, - read_video, - read_video_timestamps, -) +from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps from .utils import tqdm @@ -206,14 +201,14 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): if frame_rate is None: frame_rate = fps total_frames = len(video_pts) * (float(frame_rate) / fps) - idxs = VideoClips._resample_video_idx( - int(math.floor(total_frames)), fps, frame_rate - ) + idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) video_pts = video_pts[idxs] clips = unfold(video_pts, num_frames, step) if not clips.numel(): - warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and " - "frames between clips. The video (and potentially others) will be skipped.") + warnings.warn( + "There aren't enough frames in the current video to get a clip for the given clip length and " + "frames between clips. The video (and potentially others) will be skipped." + ) if isinstance(idxs, slice): idxs = [idxs] * len(clips) else: @@ -237,9 +232,7 @@ def compute_clips(self, num_frames, step, frame_rate=None): self.clips = [] self.resampling_idxs = [] for video_pts, fps in zip(self.video_pts, self.video_fps): - clips, idxs = self.compute_clips_for_video( - video_pts, num_frames, step, fps, frame_rate - ) + clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) self.clips.append(clips) self.resampling_idxs.append(idxs) clip_lengths = torch.as_tensor([len(v) for v in self.clips]) @@ -295,10 +288,7 @@ def get_clip(self, idx): video_idx (int): index of the video in `video_paths` """ if idx >= self.num_clips(): - raise IndexError( - "Index {} out of range " - "({} number of clips)".format(idx, self.num_clips()) - ) + raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips())) video_idx, clip_idx = self.get_clip_location(idx) video_path = self.video_paths[video_idx] clip_pts = self.clips[video_idx][clip_idx] @@ -314,13 +304,9 @@ def get_clip(self, idx): if self._video_height != 0: raise ValueError("pyav backend doesn't support _video_height != 0") if self._video_min_dimension != 0: - raise ValueError( - "pyav backend doesn't support _video_min_dimension != 0" - ) + raise ValueError("pyav backend doesn't support _video_min_dimension != 0") if self._video_max_dimension != 0: - raise ValueError( - "pyav backend doesn't support _video_max_dimension != 0" - ) + raise ValueError("pyav backend doesn't support _video_max_dimension != 0") if self._audio_samples != 0: raise ValueError("pyav backend doesn't support _audio_samples != 0") @@ -338,19 +324,11 @@ def get_clip(self, idx): audio_start_pts, audio_end_pts = 0, -1 audio_timebase = Fraction(0, 1) - video_timebase = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) if info.has_audio: - audio_timebase = Fraction( - info.audio_timebase.numerator, 
info.audio_timebase.denominator - ) - audio_start_pts = pts_convert( - video_start_pts, video_timebase, audio_timebase, math.floor - ) - audio_end_pts = pts_convert( - video_end_pts, video_timebase, audio_timebase, math.ceil - ) + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) + audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) + audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) audio_fps = info.audio_sample_rate video, audio, info = _read_video_from_file( video_path, @@ -376,9 +354,7 @@ def get_clip(self, idx): resampling_idx = resampling_idx - resampling_idx[0] video = video[resampling_idx] info["video_fps"] = self.frame_rate - assert len(video) == self.num_frames, "{} x {}".format( - video.shape, self.num_frames - ) + assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames) return video, audio, info, video_idx def __getstate__(self): diff --git a/torchvision/datasets/vision.py b/torchvision/datasets/vision.py index 2cc9ce14cb1..afd7ffc4b83 100644 --- a/torchvision/datasets/vision.py +++ b/torchvision/datasets/vision.py @@ -1,7 +1,8 @@ import os +from typing import Any, Callable, List, Optional, Tuple + import torch import torch.utils.data as data -from typing import Any, Callable, List, Optional, Tuple class VisionDataset(data.Dataset): @@ -22,14 +23,15 @@ class VisionDataset(data.Dataset): :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. """ + _repr_indent = 4 def __init__( - self, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: if isinstance(root, torch._six.string_classes): root = os.path.expanduser(root) @@ -38,8 +40,7 @@ def __init__( has_transforms = transforms is not None has_separate_transform = transform is not None or target_transform is not None if has_transforms and has_separate_transform: - raise ValueError("Only transforms or transform/target_transform can " - "be passed as argument") + raise ValueError("Only transforms or transform/target_transform can " "be passed as argument") # for backwards-compatibility self.transform = transform @@ -71,12 +72,11 @@ def __repr__(self) -> str: if hasattr(self, "transforms") and self.transforms is not None: body += [repr(self.transforms)] lines = [head] + [" " * self._repr_indent + line for line in body] - return '\n'.join(lines) + return "\n".join(lines) def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() - return (["{}{}".format(head, lines[0])] + - ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] def extra_repr(self) -> str: return "" @@ -96,16 +96,13 @@ def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() - return (["{}{}".format(head, lines[0])] + - ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] def __repr__(self) -> str: body = 
[self.__class__.__name__] if self.transform is not None: - body += self._format_transform_repr(self.transform, - "Transform: ") + body += self._format_transform_repr(self.transform, "Transform: ") if self.target_transform is not None: - body += self._format_transform_repr(self.target_transform, - "Target transform: ") + body += self._format_transform_repr(self.target_transform, "Target transform: ") - return '\n'.join(body) + return "\n".join(body) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 56bd92c7972..bf1fb7ba86b 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -1,59 +1,62 @@ -import os import collections -from .vision import VisionDataset +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple from xml.etree.ElementTree import Element as ET_Element + +from PIL import Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + try: from defusedxml.ElementTree import parse as ET_parse except ImportError: from xml.etree.ElementTree import parse as ET_parse -from PIL import Image -from typing import Any, Callable, Dict, Optional, Tuple, List -from .utils import download_and_extract_archive, verify_str_arg -import warnings DATASET_YEAR_DICT = { - '2012': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', - 'filename': 'VOCtrainval_11-May-2012.tar', - 'md5': '6cd6e144f989b92b3379bac3b3de84fd', - 'base_dir': os.path.join('VOCdevkit', 'VOC2012') + "2012": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar", + "filename": "VOCtrainval_11-May-2012.tar", + "md5": "6cd6e144f989b92b3379bac3b3de84fd", + "base_dir": os.path.join("VOCdevkit", "VOC2012"), + }, + "2011": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar", + "filename": "VOCtrainval_25-May-2011.tar", + "md5": "6c3384ef61512963050cb5d687e5bf1e", + "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"), }, - '2011': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', - 'filename': 'VOCtrainval_25-May-2011.tar', - 'md5': '6c3384ef61512963050cb5d687e5bf1e', - 'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011') + "2010": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar", + "filename": "VOCtrainval_03-May-2010.tar", + "md5": "da459979d0c395079b5c75ee67908abb", + "base_dir": os.path.join("VOCdevkit", "VOC2010"), }, - '2010': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', - 'filename': 'VOCtrainval_03-May-2010.tar', - 'md5': 'da459979d0c395079b5c75ee67908abb', - 'base_dir': os.path.join('VOCdevkit', 'VOC2010') + "2009": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar", + "filename": "VOCtrainval_11-May-2009.tar", + "md5": "59065e4b188729180974ef6572f6a212", + "base_dir": os.path.join("VOCdevkit", "VOC2009"), }, - '2009': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', - 'filename': 'VOCtrainval_11-May-2009.tar', - 'md5': '59065e4b188729180974ef6572f6a212', - 'base_dir': os.path.join('VOCdevkit', 'VOC2009') + "2008": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar", + "filename": "VOCtrainval_11-May-2012.tar", + "md5": "2629fa636546599198acfcfbfcf1904a", + "base_dir": os.path.join("VOCdevkit", "VOC2008"), }, - '2008': { - 'url': 
'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', - 'filename': 'VOCtrainval_11-May-2012.tar', - 'md5': '2629fa636546599198acfcfbfcf1904a', - 'base_dir': os.path.join('VOCdevkit', 'VOC2008') + "2007": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar", + "filename": "VOCtrainval_06-Nov-2007.tar", + "md5": "c52e279531787c972589f7e41ab4ae64", + "base_dir": os.path.join("VOCdevkit", "VOC2007"), }, - '2007': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', - 'filename': 'VOCtrainval_06-Nov-2007.tar', - 'md5': 'c52e279531787c972589f7e41ab4ae64', - 'base_dir': os.path.join('VOCdevkit', 'VOC2007') + "2007-test": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar", + "filename": "VOCtest_06-Nov-2007.tar", + "md5": "b6e924de25625d8de591ea690078ad9f", + "base_dir": os.path.join("VOCdevkit", "VOC2007"), }, - '2007-test': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', - 'filename': 'VOCtest_06-Nov-2007.tar', - 'md5': 'b6e924de25625d8de591ea690078ad9f', - 'base_dir': os.path.join('VOCdevkit', 'VOC2007') - } } diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index c1775309b29..3b7d6a77f91 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -1,10 +1,12 @@ -from PIL import Image import os from os.path import abspath, expanduser +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from PIL import Image + import torch -from typing import Any, Callable, List, Dict, Optional, Tuple, Union -from .utils import check_integrity, download_file_from_google_drive, \ - download_and_extract_archive, extract_archive, verify_str_arg + +from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg from .vision import VisionDataset @@ -40,25 +42,25 @@ class WIDERFace(VisionDataset): # File ID MD5 Hash Filename ("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"), ("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"), - ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip") + ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"), ] ANNOTATIONS_FILE = ( "http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip", "0e3767bcf0e326556d407bf5bff5d27c", - "wider_face_split.zip" + "wider_face_split.zip", ) def __init__( - self, - root: str, - split: str = "train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER), - transform=transform, - target_transform=target_transform) + super(WIDERFace, self).__init__( + root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform + ) # check arguments self.split = verify_str_arg(split, "split", ("train", "val", "test")) @@ -66,8 +68,9 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError("Dataset not found or corrupted. " + - "You can use download=True to download and prepare it") + raise RuntimeError( + "Dataset not found or corrupted. 
" + "You can use download=True to download and prepare it" + ) self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = [] if self.split in ("train", "val"): @@ -102,7 +105,7 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Split: {split}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) def parse_train_val_annotations_file(self) -> None: filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt" @@ -133,16 +136,20 @@ def parse_train_val_annotations_file(self) -> None: box_annotation_line = False file_name_line = True labels_tensor = torch.tensor(labels) - self.img_info.append({ - "img_path": img_path, - "annotations": {"bbox": labels_tensor[:, 0:4], # x, y, width, height - "blur": labels_tensor[:, 4], - "expression": labels_tensor[:, 5], - "illumination": labels_tensor[:, 6], - "occlusion": labels_tensor[:, 7], - "pose": labels_tensor[:, 8], - "invalid": labels_tensor[:, 9]} - }) + self.img_info.append( + { + "img_path": img_path, + "annotations": { + "bbox": labels_tensor[:, 0:4], # x, y, width, height + "blur": labels_tensor[:, 4], + "expression": labels_tensor[:, 5], + "illumination": labels_tensor[:, 6], + "occlusion": labels_tensor[:, 7], + "pose": labels_tensor[:, 8], + "invalid": labels_tensor[:, 9], + }, + } + ) box_counter = 0 labels.clear() else: @@ -172,7 +179,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return # download and extract image data @@ -182,6 +189,6 @@ def download(self) -> None: extract_archive(filepath) # download and extract annotation files - download_and_extract_archive(url=self.ANNOTATIONS_FILE[0], - download_root=self.root, - md5=self.ANNOTATIONS_FILE[1]) + download_and_extract_archive( + url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1] + ) diff --git a/torchvision/extension.py b/torchvision/extension.py index bea6db33636..a1697deb4cf 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -2,7 +2,6 @@ from ._internally_replaced_utils import _get_extension_path - _HAS_OPS = False @@ -11,12 +10,14 @@ def _has_ops(): try: - lib_path = _get_extension_path('_C') + lib_path = _get_extension_path("_C") torch.ops.load_library(lib_path) _HAS_OPS = True def _has_ops(): # noqa: F811 return True + + except (ImportError, OSError): pass @@ -41,6 +42,7 @@ def _check_cuda_version(): if not _HAS_OPS: return -1 import torch + _version = torch.ops.torchvision._cuda_version() if _version != -1 and torch.version.cuda is not None: tv_version = str(_version) @@ -51,14 +53,17 @@ def _check_cuda_version(): tv_major = int(tv_version[0:2]) tv_minor = int(tv_version[3]) t_version = torch.version.cuda - t_version = t_version.split('.') + t_version = t_version.split(".") t_major = int(t_version[0]) t_minor = int(t_version[1]) if t_major != tv_major or t_minor != tv_minor: - raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. " - "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " - "Please reinstall the torchvision that matches your PyTorch install." - .format(t_major, t_minor, tv_major, tv_minor)) + raise RuntimeError( + "Detected that PyTorch and torchvision were compiled with different CUDA versions. " + "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. 
" + "Please reinstall the torchvision that matches your PyTorch install.".format( + t_major, t_minor, tv_major, tv_minor + ) + ) return _version diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 742344e6b0f..6ce6e9e11f5 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -1,9 +1,9 @@ import torch from ._video_opt import ( + _HAS_VIDEO_OPT, Timebase, VideoMetaData, - _HAS_VIDEO_OPT, _probe_video_from_file, _probe_video_from_memory, _read_video_from_file, @@ -11,11 +11,6 @@ _read_video_timestamps_from_file, _read_video_timestamps_from_memory, ) -from .video import ( - read_video, - read_video_timestamps, - write_video, -) from .image import ( ImageReadMode, decode_image, @@ -29,7 +24,7 @@ write_jpeg, write_png, ) - +from .video import read_video, read_video_timestamps, write_video if _HAS_VIDEO_OPT: diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index a4a811dec4b..c5678d6ec9a 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,18 +1,16 @@ - import math -import os import warnings from fractions import Fraction from typing import List, Tuple import numpy as np + import torch from .._internally_replaced_utils import _get_extension_path - try: - lib_path = _get_extension_path('video_reader') + lib_path = _get_extension_path("video_reader") torch.ops.load_library(lib_path) _HAS_VIDEO_OPT = True except (ImportError, OSError): @@ -90,9 +88,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): """ meta = VideoMetaData() if vtimebase.numel() > 0: - meta.video_timebase = Timebase( - int(vtimebase[0].item()), int(vtimebase[1].item()) - ) + meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item())) timebase = vtimebase[0].item() / float(vtimebase[1].item()) if vduration.numel() > 0: meta.has_video = True @@ -100,9 +96,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): if vfps.numel() > 0: meta.video_fps = float(vfps.item()) if atimebase.numel() > 0: - meta.audio_timebase = Timebase( - int(atimebase[0].item()), int(atimebase[1].item()) - ) + meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item())) timebase = atimebase[0].item() / float(atimebase[1].item()) if aduration.numel() > 0: meta.has_audio = True @@ -216,10 +210,7 @@ def _read_video_from_file( audio_timebase.numerator, audio_timebase.denominator, ) - vframes, _vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) if aframes.numel() > 0: # when audio stream is found @@ -254,8 +245,7 @@ def _read_video_timestamps_from_file(filename): 0, # audio_timebase_num 1, # audio_timebase_den ) - _vframes, vframe_pts, vtimebase, vfps, vduration, \ - _aframes, aframe_pts, atimebase, asample_rate, aduration = result + _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) vframe_pts = vframe_pts.numpy().tolist() @@ -372,10 +362,7 @@ def _read_video_from_memory( audio_timebase_denominator, ) - vframes, _vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + vframes, _vframe_pts, 
vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result if aframes.numel() > 0: # when audio stream is found @@ -413,10 +400,7 @@ def _read_video_timestamps_from_memory(video_data): 0, # audio_timebase_num 1, # audio_timebase_den ) - _vframes, vframe_pts, vtimebase, vfps, vduration, \ - _aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) vframe_pts = vframe_pts.numpy().tolist() @@ -439,10 +423,10 @@ def _probe_video_from_memory(video_data): def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): - if pts_unit == 'pts': + if pts_unit == "pts": start_pts = float(start_pts * time_base) end_pts = float(end_pts * time_base) - pts_unit = 'sec' + pts_unit = "sec" return start_pts, end_pts, pts_unit @@ -467,20 +451,15 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): time_base = default_timebase if has_video: - video_timebase = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) time_base = video_timebase if has_audio: - audio_timebase = Fraction( - info.audio_timebase.numerator, info.audio_timebase.denominator - ) + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) time_base = time_base if time_base else audio_timebase # video_timebase is the default time_base - start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec( - start_pts, end_pts, pts_unit, time_base) + start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): start_offset = start_pts_sec @@ -527,9 +506,7 @@ def _read_video_timestamps(filename, pts_unit="pts"): pts, _, info = _read_video_timestamps_from_file(filename) if pts_unit == "sec": - video_time_base = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) pts = [x * video_time_base for x in pts] video_fps = info.video_fps if info.has_video else None diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 343c0b3a33d..1eb1621f779 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -1,11 +1,11 @@ -import torch from enum import Enum -from .._internally_replaced_utils import _get_extension_path +import torch +from .._internally_replaced_utils import _get_extension_path try: - lib_path = _get_extension_path('image') + lib_path = _get_extension_path("image") torch.ops.load_library(lib_path) except (ImportError, OSError): pass @@ -21,6 +21,7 @@ class ImageReadMode(Enum): ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for RGB with transparency. """ + UNCHANGED = 0 GRAY = 1 GRAY_ALPHA = 2 @@ -111,8 +112,9 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, - device: str = 'cpu') -> torch.Tensor: +def decode_jpeg( + input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu" +) -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. 
@@ -135,7 +137,7 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG output (Tensor[image_channels, image_height, image_width]) """ device = torch.device(device) - if device.type == 'cuda': + if device.type == "cuda": output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) else: output = torch.ops.image.decode_jpeg(input, mode.value) @@ -158,8 +160,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: JPEG file. """ if quality < 1 or quality > 100: - raise ValueError('Image quality should be a positive number ' - 'between 1 and 100') + raise ValueError("Image quality should be a positive number " "between 1 and 100") output = torch.ops.image.encode_jpeg(input, quality) return output diff --git a/torchvision/io/video.py b/torchvision/io/video.py index deab193eda3..e6c4371e684 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -7,11 +7,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np + import torch from . import _video_opt - try: import av @@ -94,16 +94,16 @@ def write_video( if audio_array is not None: audio_format_dtypes = { - 'dbl': ' 0 and start_offset > 0 and start_offset not in frames: # if there is no frame that exactly matches the pts of start_offset # add the last frame smaller than start_offset, to guarantee that @@ -264,7 +260,7 @@ def read_video( from torchvision import get_video_backend if not os.path.exists(filename): - raise RuntimeError(f'File not found: {filename}') + raise RuntimeError(f"File not found: {filename}") if get_video_backend() != "pyav": return _video_opt._read_video(filename, start_pts, end_pts, pts_unit) @@ -276,8 +272,7 @@ def read_video( if end_pts < start_pts: raise ValueError( - "end_pts should be larger than start_pts, got " - "start_pts={} and end_pts={}".format(start_pts, end_pts) + "end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts) ) info = {} @@ -292,8 +287,7 @@ def read_video( elif container.streams.audio: time_base = container.streams.audio[0].time_base # video_timebase is the default time_base - start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec( - start_pts, end_pts, pts_unit, time_base) + start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container, diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 283e544e98e..bd7792a02ff 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,14 +1,11 @@ +from . import detection, quantization, segmentation, video from .alexnet import * -from .resnet import * -from .vgg import * -from .squeezenet import * -from .inception import * from .densenet import * from .googlenet import * -from .mobilenet import * +from .inception import * from .mnasnet import * +from .mobilenet import * +from .resnet import * from .shufflenetv2 import * -from . import segmentation -from . import detection -from . import video -from . 
import quantization +from .squeezenet import * +from .vgg import * diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index df5ab9a044c..496a6878eb4 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,7 +1,7 @@ from collections import OrderedDict +from typing import Dict from torch import nn -from typing import Dict class IntermediateLayerGetter(nn.ModuleDict): @@ -35,6 +35,7 @@ class IntermediateLayerGetter(nn.ModuleDict): >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] """ + _version = 2 __annotations__ = { "return_layers": Dict[str, str], diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 156a453c3cc..0a3523e8278 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,19 +1,19 @@ +from typing import Any + import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['AlexNet', 'alexnet'] +__all__ = ["AlexNet", "alexnet"] model_urls = { - 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-7be5be79.pth', + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", } class AlexNet(nn.Module): - def __init__(self, num_classes: int = 1000) -> None: super(AlexNet, self).__init__() self.features = nn.Sequential( @@ -61,7 +61,6 @@ def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> A """ model = AlexNet(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['alexnet'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index aef7977773b..f341851475c 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,50 +1,46 @@ import re +from collections import OrderedDict +from typing import Any, List, Tuple + import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp -from collections import OrderedDict -from .._internally_replaced_utils import load_state_dict_from_url from torch import Tensor -from typing import Any, List, Tuple +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] +__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", } class _DenseLayer(nn.Module): def __init__( - self, - num_input_features: int, - growth_rate: int, - bn_size: int, - drop_rate: float, - memory_efficient: bool = False + self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False ) -> 
None: super(_DenseLayer, self).__init__() self.norm1: nn.BatchNorm2d - self.add_module('norm1', nn.BatchNorm2d(num_input_features)) + self.add_module("norm1", nn.BatchNorm2d(num_input_features)) self.relu1: nn.ReLU - self.add_module('relu1', nn.ReLU(inplace=True)) + self.add_module("relu1", nn.ReLU(inplace=True)) self.conv1: nn.Conv2d - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, - bias=False)) + self.add_module( + "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) + ) self.norm2: nn.BatchNorm2d - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) + self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)) self.relu2: nn.ReLU - self.add_module('relu2', nn.ReLU(inplace=True)) + self.add_module("relu2", nn.ReLU(inplace=True)) self.conv2: nn.Conv2d - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, - bias=False)) + self.add_module( + "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) + ) self.drop_rate = float(drop_rate) self.memory_efficient = memory_efficient @@ -93,8 +89,7 @@ def forward(self, input: Tensor) -> Tensor: # noqa: F811 new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) if self.drop_rate > 0: - new_features = F.dropout(new_features, p=self.drop_rate, - training=self.training) + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) return new_features @@ -108,7 +103,7 @@ def __init__( bn_size: int, growth_rate: int, drop_rate: float, - memory_efficient: bool = False + memory_efficient: bool = False, ) -> None: super(_DenseBlock, self).__init__() for i in range(num_layers): @@ -119,7 +114,7 @@ def __init__( drop_rate=drop_rate, memory_efficient=memory_efficient, ) - self.add_module('denselayer%d' % (i + 1), layer) + self.add_module("denselayer%d" % (i + 1), layer) def forward(self, init_features: Tensor) -> Tensor: features = [init_features] @@ -132,11 +127,10 @@ def forward(self, init_features: Tensor) -> Tensor: class _Transition(nn.Sequential): def __init__(self, num_input_features: int, num_output_features: int) -> None: super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + self.add_module("norm", nn.BatchNorm2d(num_input_features)) + self.add_module("relu", nn.ReLU(inplace=True)) + self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): @@ -163,19 +157,22 @@ def __init__( bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000, - memory_efficient: bool = False + memory_efficient: bool = False, ) -> None: super(DenseNet, self).__init__() # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, - padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) + self.features = nn.Sequential( + OrderedDict( + [ + ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 
padding=3, bias=False)), + ("norm0", nn.BatchNorm2d(num_init_features)), + ("relu0", nn.ReLU(inplace=True)), + ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ] + ) + ) # Each denseblock num_features = num_init_features @@ -186,18 +183,17 @@ def __init__( bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, - memory_efficient=memory_efficient + memory_efficient=memory_efficient, ) - self.features.add_module('denseblock%d' % (i + 1), block) + self.features.add_module("denseblock%d" % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, - num_output_features=num_features // 2) - self.features.add_module('transition%d' % (i + 1), trans) + trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) + self.features.add_module("transition%d" % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + self.features.add_module("norm5", nn.BatchNorm2d(num_features)) # Linear layer self.classifier = nn.Linear(num_features, num_classes) @@ -227,7 +223,8 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): @@ -246,7 +243,7 @@ def _densenet( num_init_features: int, pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ) -> DenseNet: model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) if pretrained: @@ -265,8 +262,7 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, - **kwargs) + return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -280,8 +276,7 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, - **kwargs) + return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -295,8 +290,7 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, - **kwargs) + return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -310,5 +304,4 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, - **kwargs) + return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 4772415b3b1..13edbf75575 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,6 +1,6 @@ from .faster_rcnn import * -from .mask_rcnn import * from .keypoint_rcnn import * +from .mask_rcnn import * from .retinanet import * from .ssd import * from .ssdlite import * diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 40281b39b6b..d14d22d8a09 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -1,10 +1,9 @@ import math -import torch - from collections import OrderedDict -from torch import Tensor from typing import List, Tuple +import torch +from torch import Tensor from torchvision.ops.misc import FrozenBatchNorm2d @@ -61,12 +60,8 @@ def __call__(self, matched_idxs): neg_idx_per_image = negative[perm2] # create binary mask from indices - pos_idx_per_image_mask = torch.zeros_like( - matched_idxs_per_image, dtype=torch.uint8 - ) - neg_idx_per_image_mask = torch.zeros_like( - matched_idxs_per_image, dtype=torch.uint8 - ) + pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) + neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) pos_idx_per_image_mask[pos_idx_per_image] = 1 neg_idx_per_image_mask[neg_idx_per_image] = 1 @@ -132,7 +127,7 @@ class BoxCoder(object): the representation used for training the regressors. """ - def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)): + def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): # type: (Tuple[float, float, float, float], float) -> None """ Args: @@ -177,9 +172,7 @@ def decode(self, rel_codes, boxes): box_sum += val if box_sum > 0: rel_codes = rel_codes.reshape(box_sum, -1) - pred_boxes = self.decode_single( - rel_codes, concat_boxes - ) + pred_boxes = self.decode_single(rel_codes, concat_boxes) if box_sum > 0: pred_boxes = pred_boxes.reshape(box_sum, -1, 4) return pred_boxes @@ -243,8 +236,8 @@ class Matcher(object): BETWEEN_THRESHOLDS = -2 __annotations__ = { - 'BELOW_LOW_THRESHOLD': int, - 'BETWEEN_THRESHOLDS': int, + "BELOW_LOW_THRESHOLD": int, + "BETWEEN_THRESHOLDS": int, } def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): @@ -283,13 +276,9 @@ def __call__(self, match_quality_matrix): if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training if match_quality_matrix.shape[0] == 0: - raise ValueError( - "No ground-truth boxes available for one of the images " - "during training") + raise ValueError("No ground-truth boxes available for one of the images " "during training") else: - raise ValueError( - "No proposal boxes available for one of the images " - "during training") + raise ValueError("No proposal boxes available for one of the images " "during training") # match_quality_matrix is M (gt) x N (predicted) # Max over gt elements (dim 0) to find best gt candidate for each prediction @@ -301,9 +290,7 @@ def __call__(self, match_quality_matrix): # Assign candidate matches with low quality to negative (unassigned) values below_low_threshold = matched_vals < self.low_threshold - between_thresholds = (matched_vals >= self.low_threshold) & ( - matched_vals < self.high_threshold - ) + between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold) matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD matches[between_thresholds] = self.BETWEEN_THRESHOLDS @@ -324,9 +311,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): # For each gt, find the prediction with which it has highest quality highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) # Find highest quality match available, even if it is low, including ties - gt_pred_pairs_of_highest_quality = torch.where( - match_quality_matrix == highest_quality_foreach_gt[:, None] - ) + gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None]) # Example gt_pred_pairs_of_highest_quality: # tensor([[ 0, 39796], # [ 1, 32055], @@ -346,7 +331,6 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): class SSDMatcher(Matcher): - def __init__(self, threshold): super().__init__(threshold, threshold, allow_low_quality_matches=False) @@ -355,9 +339,9 @@ def __call__(self, match_quality_matrix): # For each gt, find the prediction with which it has the highest quality _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) - matches[highest_quality_pred_foreach_gt] = torch.arange(highest_quality_pred_foreach_gt.size(0), - dtype=torch.int64, - device=highest_quality_pred_foreach_gt.device) + matches[highest_quality_pred_foreach_gt] = torch.arange( + highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device + ) return matches @@ -401,7 +385,7 @@ def retrieve_out_channels(model, size): tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device) features = model(tmp_img) if 
isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) out_channels = [x.size(1) for x in features.values()] if in_training: diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 06ecc551442..6c073077546 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,8 +1,9 @@ import math +from typing import List, Optional + import torch -from torch import nn, Tensor +from torch import Tensor, nn -from typing import List, Optional from .image_list import ImageList @@ -48,15 +49,21 @@ def __init__( self.sizes = sizes self.aspect_ratios = aspect_ratios - self.cell_anchors = [self.generate_anchors(size, aspect_ratio) - for size, aspect_ratio in zip(sizes, aspect_ratios)] + self.cell_anchors = [ + self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios) + ] # TODO: https://github.com/pytorch/pytorch/issues/26792 # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) # This method assumes aspect ratio = height / width for an anchor. - def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu")): + def generate_anchors( + self, + scales: List[int], + aspect_ratios: List[float], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): scales = torch.as_tensor(scales, dtype=dtype, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) h_ratios = torch.sqrt(aspect_ratios) @@ -69,8 +76,7 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: return base_anchors.round() def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): - self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) - for cell_anchor in self.cell_anchors] + self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] def num_anchors_per_location(self): return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] @@ -83,25 +89,21 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) assert cell_anchors is not None if not (len(grid_sizes) == len(strides) == len(cell_anchors)): - raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature " - "map could potentially have different sizes and aspect ratios. " - "There needs to be a match between the number of " - "feature maps passed and the number of sizes / aspect ratios specified.") - - for size, stride, base_anchors in zip( - grid_sizes, strides, cell_anchors - ): + raise ValueError( + "Anchors should be Tuple[Tuple[int]] because each feature " + "map could potentially have different sizes and aspect ratios. " + "There needs to be a match between the number of " + "feature maps passed and the number of sizes / aspect ratios specified." 
+ ) + + for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors): grid_height, grid_width = size stride_height, stride_width = stride device = base_anchors.device # For output anchor, compute [x_center, y_center, x_center, y_center] - shifts_x = torch.arange( - 0, grid_width, dtype=torch.float32, device=device - ) * stride_width - shifts_y = torch.arange( - 0, grid_height, dtype=torch.float32, device=device - ) * stride_height + shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) * stride_width + shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) * stride_height shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) @@ -109,9 +111,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) # For every (base anchor, output anchor) pair, # offset each zero-centered base anchor by the center of the output anchor. - anchors.append( - (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) - ) + anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)) return anchors @@ -119,8 +119,13 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] image_size = image_list.tensors.shape[-2:] dtype, device = feature_maps[0].dtype, feature_maps[0].device - strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), - torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] + strides = [ + [ + torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), + torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device), + ] + for g in grid_sizes + ] self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) anchors: List[List[torch.Tensor]] = [] @@ -149,8 +154,15 @@ class DefaultBoxGenerator(nn.Module): is applied while the boxes are encoded in format ``(cx, cy, w, h)``. 
""" - def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9, - scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True): + def __init__( + self, + aspect_ratios: List[List[int]], + min_ratio: float = 0.15, + max_ratio: float = 0.9, + scales: Optional[List[float]] = None, + steps: Optional[List[int]] = None, + clip: bool = True, + ): super().__init__() if steps is not None: assert len(aspect_ratios) == len(steps) @@ -172,8 +184,9 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ self._wh_pairs = self._generate_wh_pairs(num_outputs) - def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu")) -> List[Tensor]: + def _generate_wh_pairs( + self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu") + ) -> List[Tensor]: _wh_pairs: List[Tensor] = [] for k in range(num_outputs): # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k @@ -196,8 +209,9 @@ def num_anchors_per_location(self): return [2 + 2 * len(r) for r in self.aspect_ratios] # Default Boxes calculation based on page 6 of SSD paper - def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int], - dtype: torch.dtype = torch.float32) -> Tensor: + def _grid_default_boxes( + self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32 + ) -> Tensor: default_boxes = [] for k, f_k in enumerate(grid_sizes): # Now add the default boxes for each width-height pair @@ -224,12 +238,12 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int] return torch.cat(default_boxes, dim=0) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'aspect_ratios={aspect_ratios}' - s += ', clip={clip}' - s += ', scales={scales}' - s += ', steps={steps}' - s += ')' + s = self.__class__.__name__ + "(" + s += "aspect_ratios={aspect_ratios}" + s += ", clip={clip}" + s += ", scales={scales}" + s += ", steps={steps}" + s += ")" return s.format(**self.__dict__) def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: @@ -242,8 +256,13 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten dboxes = [] for _ in image_list.image_sizes: dboxes_in_image = default_boxes - dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], - dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1) + dboxes_in_image = torch.cat( + [ + dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], + dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:], + ], + -1, + ) dboxes_in_image[:, 0::2] *= image_size[1] dboxes_in_image[:, 1::2] *= image_size[0] dboxes.append(dboxes_in_image) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 3178a81b52c..6b169de3d07 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,11 +1,11 @@ import warnings + from torch import nn +from torchvision.ops import misc as misc_nn_ops from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool -from torchvision.ops import misc as misc_nn_ops +from .. import mobilenet, resnet from .._utils import IntermediateLayerGetter -from .. import mobilenet -from .. 
import resnet class BackboneWithFPN(nn.Module): @@ -26,6 +26,7 @@ class BackboneWithFPN(nn.Module): Attributes: out_channels (int): the number of channels in the FPN """ + def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): super(BackboneWithFPN, self).__init__() @@ -52,7 +53,7 @@ def resnet_fpn_backbone( norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3, returned_layers=None, - extra_blocks=None + extra_blocks=None, ): """ Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. @@ -89,15 +90,13 @@ def resnet_fpn_backbone( a new list of feature maps and their corresponding names. By default a ``LastLevelMaxPool`` is used. """ - backbone = resnet.__dict__[backbone_name]( - pretrained=pretrained, - norm_layer=norm_layer) + backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) # select layers that wont be frozen assert 0 <= trainable_layers <= 5 - layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] + layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] if trainable_layers == 5: - layers_to_train.append('bn1') + layers_to_train.append("bn1") for name, parameter in backbone.named_parameters(): if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) @@ -108,7 +107,7 @@ def resnet_fpn_backbone( if returned_layers is None: returned_layers = [1, 2, 3, 4] assert min(returned_layers) > 0 and max(returned_layers) < 5 - return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} + return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} in_channels_stage2 = backbone.inplanes // 8 in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] @@ -123,7 +122,8 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, warnings.warn( "Changing trainable_backbone_layers has not effect if " "neither pretrained nor pretrained_backbone have been set to True, " - "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)) + "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value) + ) trainable_backbone_layers = max_value # by default freeze first blocks @@ -140,7 +140,7 @@ def mobilenet_backbone( norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=2, returned_layers=None, - extra_blocks=None + extra_blocks=None, ): backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features @@ -165,7 +165,7 @@ def mobilenet_backbone( if returned_layers is None: returned_layers = [num_stages - 2, num_stages - 1] assert min(returned_layers) >= 0 and max(returned_layers) < num_stages - return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)} + return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)} in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 83f2eb88f88..5f66f2cb667 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,22 +1,21 @@ -from torch import nn import torch.nn.functional as F - +from torch import 
nn from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator +from .backbone_utils import _validate_trainable_layers, mobilenet_backbone, resnet_fpn_backbone from .generalized_rcnn import GeneralizedRCNN -from .rpn import RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads +from .rpn import RegionProposalNetwork, RPNHead from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone - __all__ = [ - "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", - "fasterrcnn_mobilenet_v3_large_fpn" + "FasterRCNN", + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", ] @@ -141,30 +140,48 @@ class FasterRCNN(GeneralizedRCNN): >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None): + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + ): if not hasattr(backbone, "out_channels"): raise ValueError( "backbone should contain an attribute out_channels " "specifying the number of output channels (assumed to be the " - "same for all the levels)") + "same for all the levels)" + ) assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) @@ -174,58 +191,59 @@ def __init__(self, backbone, num_classes=None, raise ValueError("num_classes should be None when box_predictor is specified") else: if box_predictor is None: - raise ValueError("num_classes should not be None when box_predictor " - "is not specified") + raise ValueError("num_classes should not be None when box_predictor " "is not specified") out_channels = backbone.out_channels if rpn_anchor_generator is None: anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = AnchorGenerator( - anchor_sizes, 
aspect_ratios - ) + rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) if rpn_head is None: - rpn_head = RPNHead( - out_channels, rpn_anchor_generator.num_anchors_per_location()[0] - ) + rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) rpn = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, - score_thresh=rpn_score_thresh) + rpn_anchor_generator, + rpn_head, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_nms_thresh, + score_thresh=rpn_score_thresh, + ) if box_roi_pool is None: - box_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) + box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) if box_head is None: resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - out_channels * resolution ** 2, - representation_size) + box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) if box_predictor is None: representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - num_classes) + box_predictor = FastRCNNPredictor(representation_size, num_classes) roi_heads = RoIHeads( # Box - box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, + box_roi_pool, + box_head, + box_predictor, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + ) if image_mean is None: image_mean = [0.485, 0.456, 0.406] @@ -286,17 +304,15 @@ def forward(self, x): model_urls = { - 'fasterrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', - 'fasterrcnn_mobilenet_v3_large_320_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', - 'fasterrcnn_mobilenet_v3_large_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' + "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", # noqa: E501 + "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", # noqa: E501 + "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", # noqa: E501 } -def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def fasterrcnn_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. @@ -359,36 +375,54 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = FasterRCNN(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model -def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91, - pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=False, + progress=True, + num_classes=91, + pretrained_backbone=True, + trainable_backbone_layers=None, + **kwargs, +): trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3 + ) if pretrained: pretrained_backbone = False - backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True, - trainable_layers=trainable_backbone_layers) - - anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3 + backbone = mobilenet_backbone( + "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers + ) + + anchor_sizes = ( + ( + 32, + 64, + 128, + 256, + 512, + ), + ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), - **kwargs) + model = FasterRCNN( + backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs + ) if pretrained: if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) @@ -397,8 +431,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress= return model -def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, **kwargs): +def fasterrcnn_mobilenet_v3_large_320_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. 
See @@ -430,13 +465,20 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c } kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, - num_classes=num_classes, pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, **kwargs) - - -def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, **kwargs): + return _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=pretrained, + progress=progress, + num_classes=num_classes, + pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) + + +def fasterrcnn_mobilenet_v3_large_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -464,6 +506,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class } kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, - num_classes=num_classes, pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, **kwargs) + return _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=pretrained, + progress=progress, + num_classes=num_classes, + pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 1d3979caa3f..93f50648517 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -2,11 +2,12 @@ Implements the Generalized R-CNN framework """ +import warnings from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + import torch -from torch import nn, Tensor -import warnings -from typing import Tuple, List, Dict, Optional, Union +from torch import Tensor, nn class GeneralizedRCNN(nn.Module): @@ -61,12 +62,11 @@ def forward(self, images, targets=None): boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) original_image_sizes: List[Tuple[int, int]] = [] for img in images: @@ -86,13 +86,14 @@ def forward(self, images, targets=None): # print the first degenerate box bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." 
+ " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) proposals, proposal_losses = self.rpn(images, features, targets) detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index a389b3c3ce1..333d3b569f2 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -1,6 +1,7 @@ +from typing import List, Tuple + import torch from torch import Tensor -from typing import List, Tuple class ImageList(object): @@ -20,6 +21,6 @@ def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]): self.tensors = tensors self.image_sizes = image_sizes - def to(self, device: torch.device) -> 'ImageList': + def to(self, device: torch.device) -> "ImageList": cast_tensor = self.tensors.to(device) return ImageList(cast_tensor, self.image_sizes) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index fd9a980b97d..dad1b2e2a5f 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,18 +1,13 @@ import torch from torch import nn - from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ._utils import overwrite_eps +from .backbone_utils import _validate_trainable_layers, resnet_fpn_backbone from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers - -__all__ = [ - "KeypointRCNN", "keypointrcnn_resnet50_fpn" -] +__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"] class KeypointRCNN(FasterRCNN): @@ -151,27 +146,47 @@ class KeypointRCNN(FasterRCNN): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=None, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None, - # keypoint parameters - keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, - num_keypoints=17): + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=None, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + 
rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # keypoint parameters + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + num_keypoints=17, + ): assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) if min_size is None: @@ -184,10 +199,7 @@ def __init__(self, backbone, num_classes=None, out_channels = backbone.out_channels if keypoint_roi_pool is None: - keypoint_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=14, - sampling_ratio=2) + keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) if keypoint_head is None: keypoint_layers = tuple(512 for _ in range(8)) @@ -198,24 +210,39 @@ def __init__(self, backbone, num_classes=None, keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) super(KeypointRCNN, self).__init__( - backbone, num_classes, + backbone, + num_classes, # transform parameters - min_size, max_size, - image_mean, image_std, + min_size, + max_size, + image_mean, + image_std, # RPN-specific parameters - rpn_anchor_generator, rpn_head, - rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, - rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, rpn_nms_thresh, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, rpn_score_thresh, # Box parameters - box_roi_pool, box_head, box_predictor, - box_score_thresh, box_nms_thresh, box_detections_per_img, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights) + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + ) self.roi_heads.keypoint_roi_pool = keypoint_roi_pool self.roi_heads.keypoint_head = keypoint_head @@ -249,9 +276,7 @@ def __init__(self, in_channels, num_keypoints): stride=2, padding=deconv_kernel // 2 - 1, ) - nn.init.kaiming_normal_( - self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu" - ) + nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu") nn.init.constant_(self.kps_score_lowres.bias, 0) self.up_scale = 2 self.out_channels = num_keypoints @@ -265,16 +290,20 @@ def forward(self, x): model_urls = { # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606 - 'keypointrcnn_resnet50_fpn_coco_legacy': - 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth', - 'keypointrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth', + "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", # noqa: E501 + "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", } -def keypointrcnn_resnet50_fpn(pretrained=False, 
progress=True, - num_classes=2, num_keypoints=17, - pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def keypointrcnn_resnet50_fpn( + pretrained=False, + progress=True, + num_classes=2, + num_keypoints=17, + pretrained_backbone=True, + trainable_backbone_layers=None, + **kwargs, +): """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -329,19 +358,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) if pretrained: - key = 'keypointrcnn_resnet50_fpn_coco' - if pretrained == 'legacy': - key += '_legacy' - state_dict = load_state_dict_from_url(model_urls[key], - progress=progress) + key = "keypointrcnn_resnet50_fpn_coco" + if pretrained == "legacy": + key += "_legacy" + state_dict = load_state_dict_from_url(model_urls[key], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index ad8f356ad69..d49ee634189 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,17 +1,16 @@ from collections import OrderedDict from torch import nn - from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ._utils import overwrite_eps +from .backbone_utils import _validate_trainable_layers, resnet_fpn_backbone from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ - "MaskRCNN", "maskrcnn_resnet50_fpn", + "MaskRCNN", + "maskrcnn_resnet50_fpn", ] @@ -149,26 +148,46 @@ class MaskRCNN(FasterRCNN): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None, - # Mask parameters - mask_roi_pool=None, mask_head=None, mask_predictor=None): + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + 
rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # Mask parameters + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + ): assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))) @@ -179,10 +198,7 @@ def __init__(self, backbone, num_classes=None, out_channels = backbone.out_channels if mask_roi_pool is None: - mask_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=14, - sampling_ratio=2) + mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) if mask_head is None: mask_layers = (256, 256, 256, 256) @@ -192,28 +208,42 @@ def __init__(self, backbone, num_classes=None, if mask_predictor is None: mask_predictor_in_channels = 256 # == mask_layers[-1] mask_dim_reduced = 256 - mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, - mask_dim_reduced, num_classes) + mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes) super(MaskRCNN, self).__init__( - backbone, num_classes, + backbone, + num_classes, # transform parameters - min_size, max_size, - image_mean, image_std, + min_size, + max_size, + image_mean, + image_std, # RPN-specific parameters - rpn_anchor_generator, rpn_head, - rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, - rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, rpn_nms_thresh, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, rpn_score_thresh, # Box parameters - box_roi_pool, box_head, box_predictor, - box_score_thresh, box_nms_thresh, box_detections_per_img, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights) + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + ) self.roi_heads.mask_roi_pool = mask_roi_pool self.roi_heads.mask_head = mask_head @@ -232,8 +262,8 @@ def __init__(self, in_channels, layers, dilation): next_feature = in_channels for layer_idx, layer_features in enumerate(layers, 1): d["mask_fcn{}".format(layer_idx)] = nn.Conv2d( - next_feature, layer_features, kernel_size=3, - stride=1, padding=dilation, dilation=dilation) + next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation + ) d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True) next_feature = layer_features @@ -247,11 +277,15 @@ def __init__(self, in_channels, layers, dilation): class MaskRCNNPredictor(nn.Sequential): def __init__(self, in_channels, dim_reduced, num_classes): - super(MaskRCNNPredictor, self).__init__(OrderedDict([ - ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 
0)), - ("relu", nn.ReLU(inplace=True)), - ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)), - ])) + super(MaskRCNNPredictor, self).__init__( + OrderedDict( + [ + ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)), + ("relu", nn.ReLU(inplace=True)), + ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)), + ] + ) + ) for name, param in self.named_parameters(): if "weight" in name: @@ -261,13 +295,13 @@ def __init__(self, in_channels, dim_reduced, num_classes): model_urls = { - 'maskrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth', + "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", } -def maskrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def maskrcnn_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -322,16 +356,16 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = MaskRCNN(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 4dd95285dbc..faa2817d5dd 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,26 +1,21 @@ import math -from collections import OrderedDict import warnings +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple import torch -from torch import nn, Tensor -from typing import Dict, List, Tuple, Optional +from torch import Tensor, nn -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ...ops import boxes as box_ops, sigmoid_focal_loss +from ...ops.feature_pyramid_network import LastLevelP6P7 from . 
import _utils as det_utils +from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator +from .backbone_utils import _validate_trainable_layers, resnet_fpn_backbone from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers -from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...ops import sigmoid_focal_loss -from ...ops import boxes as box_ops - -__all__ = [ - "RetinaNet", "retinanet_resnet50_fpn" -] +__all__ = ["RetinaNet", "retinanet_resnet50_fpn"] def _sum(x: List[Tensor]) -> Tensor: @@ -48,16 +43,13 @@ def __init__(self, in_channels, num_anchors, num_classes): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] return { - 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), - 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), + "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } def forward(self, x): # type: (List[Tensor]) -> Dict[str, Tensor] - return { - 'cls_logits': self.classification_head(x), - 'bbox_regression': self.regression_head(x) - } + return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} class RetinaNetClassificationHead(nn.Module): @@ -100,7 +92,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor losses = [] - cls_logits = head_outputs['cls_logits'] + cls_logits = head_outputs["cls_logits"] for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): # determine only the foreground @@ -111,18 +103,21 @@ def compute_loss(self, targets, head_outputs, matched_idxs): gt_classes_target = torch.zeros_like(cls_logits_per_image) gt_classes_target[ foreground_idxs_per_image, - targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]] + targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], ] = 1.0 # find indices for which anchors should be ignored valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS # compute the classification loss - losses.append(sigmoid_focal_loss( - cls_logits_per_image[valid_idxs_per_image], - gt_classes_target[valid_idxs_per_image], - reduction='sum', - ) / max(1, num_foreground)) + losses.append( + sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction="sum", + ) + / max(1, num_foreground) + ) return _sum(losses) / len(targets) @@ -153,8 +148,9 @@ class RetinaNetRegressionHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, + "box_coder": det_utils.BoxCoder, } def __init__(self, in_channels, num_anchors): @@ -181,16 +177,17 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor losses = [] - bbox_regression = head_outputs['bbox_regression'] + bbox_regression = head_outputs["bbox_regression"] - for targets_per_image, bbox_regression_per_image, anchors_per_image, 
matched_idxs_per_image in \ - zip(targets, bbox_regression, anchors, matched_idxs): + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip( + targets, bbox_regression, anchors, matched_idxs + ): # determine only the foreground indices, ignore the rest foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] num_foreground = foreground_idxs_per_image.numel() # select only the foreground boxes - matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]] + matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]] bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] @@ -198,11 +195,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - losses.append(torch.nn.functional.l1_loss( - bbox_regression_per_image, - target_regression, - reduction='sum' - ) / max(1, num_foreground)) + losses.append( + torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + / max(1, num_foreground) + ) return _sum(losses) / max(1, len(targets)) @@ -309,30 +305,40 @@ class RetinaNet(nn.Module): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, } - def __init__(self, backbone, num_classes, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # Anchor parameters - anchor_generator=None, head=None, - proposal_matcher=None, - score_thresh=0.05, - nms_thresh=0.5, - detections_per_img=300, - fg_iou_thresh=0.5, bg_iou_thresh=0.4, - topk_candidates=1000): + def __init__( + self, + backbone, + num_classes, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # Anchor parameters + anchor_generator=None, + head=None, + proposal_matcher=None, + score_thresh=0.05, + nms_thresh=0.5, + detections_per_img=300, + fg_iou_thresh=0.5, + bg_iou_thresh=0.4, + topk_candidates=1000, + ): super().__init__() if not hasattr(backbone, "out_channels"): raise ValueError( "backbone should contain an attribute out_channels " "specifying the number of output channels (assumed to be the " - "same for all the levels)") + "same for all the levels)" + ) self.backbone = backbone assert isinstance(anchor_generator, (AnchorGenerator, type(None))) @@ -340,9 +346,7 @@ def __init__(self, backbone, num_classes, if anchor_generator is None: anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - anchor_generator = AnchorGenerator( - anchor_sizes, aspect_ratios - ) + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) self.anchor_generator = anchor_generator if head is None: @@ -385,20 +389,21 @@ def compute_loss(self, targets, head_outputs, anchors): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): - if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), 
-1, dtype=torch.int64, - device=anchors_per_image.device)) + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) continue - match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) matched_idxs.append(self.proposal_matcher(match_quality_matrix)) return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections(self, head_outputs, anchors, image_shapes): # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] - class_logits = head_outputs['cls_logits'] - box_regression = head_outputs['bbox_regression'] + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] num_images = len(image_shapes) @@ -413,8 +418,9 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): image_scores = [] image_labels = [] - for box_regression_per_level, logits_per_level, anchors_per_level in \ - zip(box_regression_per_image, logits_per_image, anchors_per_image): + for box_regression_per_level, logits_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, anchors_per_image + ): num_classes = logits_per_level.shape[-1] # remove low scoring boxes @@ -428,11 +434,12 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): scores_per_level, idxs = scores_per_level.topk(num_topk) topk_idxs = topk_idxs[idxs] - anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor') + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") labels_per_level = topk_idxs % num_classes - boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs], - anchors_per_level[anchor_idxs]) + boxes_per_level = self.box_coder.decode_single( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) image_boxes.append(boxes_per_level) @@ -445,13 +452,15 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): # non-maximum suppression keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) - keep = keep[:self.detections_per_img] - - detections.append({ - 'boxes': image_boxes[keep], - 'scores': image_scores[keep], - 'labels': image_labels[keep], - }) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) return detections @@ -478,12 +487,11 @@ def forward(self, images, targets=None): boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] @@ -505,14 +513,15 @@ def forward(self, images, targets=None): # print the first degenerate box bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] 
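# --- Editor's illustrative sketch (not part of the diff): the validation rewrapped
# above expects training targets in this form -- float "boxes" of shape [N, 4] in
# (x1, y1, x2, y2) order with positive width and height, plus int64 "labels". The
# sample values are made up; the degenerate-box test mirrors the check in forward().
import torch

target = {
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]], dtype=torch.float32),
    "labels": torch.tensor([3], dtype=torch.int64),
}

boxes = target["boxes"]
assert isinstance(boxes, torch.Tensor)
assert len(boxes.shape) == 2 and boxes.shape[-1] == 4

degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]    # flags any box with x2 <= x1 or y2 <= y1
assert not degenerate_boxes.any()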
degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) # get the features from the backbone features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) # TODO: Do we want a list or a dict? features = list(features.values()) @@ -536,7 +545,7 @@ def forward(self, images, targets=None): HW = 0 for v in num_anchors_per_level: HW += v - HWA = head_outputs['cls_logits'].size(1) + HWA = head_outputs["cls_logits"].size(1) A = HWA // HW num_anchors_per_level = [hw * A for hw in num_anchors_per_level] @@ -559,13 +568,13 @@ def forward(self, images, targets=None): model_urls = { - 'retinanet_resnet50_fpn_coco': - 'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth', + "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", } -def retinanet_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def retinanet_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. @@ -611,18 +620,23 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False # skip P2 because it generates too many anchors (according to their paper) - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, returned_layers=[2, 3, 4], - extra_blocks=LastLevelP6P7(256, 256), trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone( + "resnet50", + pretrained_backbone, + returned_layers=[2, 3, 4], + extra_blocks=LastLevelP6P7(256, 256), + trainable_layers=trainable_backbone_layers, + ) model = RetinaNet(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 9948d5f537f..18888add9cd 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -1,17 +1,13 @@ -import torch -import torchvision +from typing import Dict, List, Optional, Tuple +import torch import torch.nn.functional as F -from torch import nn, Tensor - -from torchvision.ops import boxes as box_ops - -from torchvision.ops import roi_align +import torchvision +from torch import Tensor, nn +from torchvision.ops import boxes as box_ops, roi_align from . 
import _utils as det_utils -from typing import Optional, List, Dict, Tuple - def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] @@ -46,7 +42,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): box_regression[sampled_pos_inds_subset, labels_pos], regression_targets[sampled_pos_inds_subset], beta=1 / 9, - reduction='sum', + reduction="sum", ) box_loss = box_loss / labels.numel() @@ -95,7 +91,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): matched_idxs = matched_idxs.to(boxes) rois = torch.cat([matched_idxs[:, None], boxes], dim=1) gt_masks = gt_masks[:, None].to(rois) - return roi_align(gt_masks, rois, (M, M), 1.)[:, 0] + return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0] def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs): @@ -113,8 +109,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs discretization_size = mask_logits.shape[-1] labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)] mask_targets = [ - project_masks_on_boxes(m, p, i, discretization_size) - for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) + project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) ] labels = torch.cat(labels, dim=0) @@ -167,59 +162,72 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size): return heatmaps, valid -def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, - widths_i, heights_i, offset_x_i, offset_y_i): +def _onnx_heatmaps_to_keypoints( + maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i +): num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64) width_correction = widths_i / roi_map_width height_correction = heights_i / roi_map_height roi_map = F.interpolate( - maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0] + maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False + )[:, 0] w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) - x_int = (pos % w) - y_int = ((pos - x_int) // w) + x_int = pos % w + y_int = (pos - x_int) // w - x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * \ - width_correction.to(dtype=torch.float32) - y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * \ - height_correction.to(dtype=torch.float32) + x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to( + dtype=torch.float32 + ) + y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to( + dtype=torch.float32 + ) xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32) xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32) xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32) - xy_preds_i = torch.stack([xy_preds_i_0.to(dtype=torch.float32), - xy_preds_i_1.to(dtype=torch.float32), - xy_preds_i_2.to(dtype=torch.float32)], 0) + xy_preds_i = torch.stack( + [ + xy_preds_i_0.to(dtype=torch.float32), + xy_preds_i_1.to(dtype=torch.float32), + xy_preds_i_2.to(dtype=torch.float32), + ], + 0, + ) # TODO: simplify when indexing without rank will be supported by ONNX base = num_keypoints * num_keypoints + num_keypoints + 1 ind = 
torch.arange(num_keypoints) ind = ind.to(dtype=torch.int64) * base - end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \ - .index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64)) + end_scores_i = ( + roi_map.index_select(1, y_int.to(dtype=torch.int64)) + .index_select(2, x_int.to(dtype=torch.int64)) + .view(-1) + .index_select(0, ind.to(dtype=torch.int64)) + ) return xy_preds_i, end_scores_i @torch.jit._script_if_tracing -def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil, - widths, heights, offset_x, offset_y, num_keypoints): +def _onnx_heatmaps_to_keypoints_loop( + maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints +): xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device) end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device) for i in range(int(rois.size(0))): - xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(maps, maps[i], - widths_ceil[i], heights_ceil[i], - widths[i], heights[i], - offset_x[i], offset_y[i]) - xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), - xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) - end_scores = torch.cat((end_scores.to(dtype=torch.float32), - end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0) + xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints( + maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i] + ) + xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) + end_scores = torch.cat( + (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0 + ) return xy_preds, end_scores @@ -246,10 +254,17 @@ def heatmaps_to_keypoints(maps, rois): num_keypoints = maps.shape[1] if torchvision._is_tracing(): - xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(maps, rois, - widths_ceil, heights_ceil, widths, heights, - offset_x, offset_y, - torch.scalar_tensor(num_keypoints, dtype=torch.int64)) + xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop( + maps, + rois, + widths_ceil, + heights_ceil, + widths, + heights, + offset_x, + offset_y, + torch.scalar_tensor(num_keypoints, dtype=torch.int64), + ) return xy_preds.permute(0, 2, 1), end_scores xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) @@ -260,13 +275,14 @@ def heatmaps_to_keypoints(maps, rois): width_correction = widths[i] / roi_map_width height_correction = heights[i] / roi_map_height roi_map = F.interpolate( - maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0] + maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False + )[:, 0] # roi_map_probs = scores_to_probs(roi_map.copy()) w = roi_map.shape[2] pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) x_int = pos % w - y_int = torch.div(pos - x_int, w, rounding_mode='floor') + y_int = torch.div(pos - x_int, w, rounding_mode="floor") # assert (roi_map_probs[k, y_int, x_int] == # roi_map_probs[k, :, :].max()) x = (x_int.float() + 0.5) * width_correction @@ -288,9 +304,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched valid = [] for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): kp = gt_kp_in_image[midx] - heatmaps_per_image, valid_per_image = keypoints_to_heatmap( - kp, proposals_per_image, 
discretization_size - ) + heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size) heatmaps.append(heatmaps_per_image.view(-1)) valid.append(valid_per_image.view(-1)) @@ -327,10 +341,10 @@ def keypointrcnn_inference(x, boxes): def _onnx_expand_boxes(boxes, scale): # type: (Tensor, float) -> Tensor - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 w_half = w_half.to(dtype=torch.float32) * scale h_half = h_half.to(dtype=torch.float32) * scale @@ -350,10 +364,10 @@ def expand_boxes(boxes, scale): # type: (Tensor, float) -> Tensor if torchvision._is_tracing(): return _onnx_expand_boxes(boxes, scale) - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 w_half *= scale h_half *= scale @@ -395,7 +409,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, -1, -1)) # Resize mask - mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False) mask = mask[0][0] im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) @@ -404,9 +418,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): y_0 = max(box[1], 0) y_1 = min(box[3] + 1, im_h) - im_mask[y_0:y_1, x_0:x_1] = mask[ - (y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0]) - ] + im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] return im_mask @@ -414,8 +426,8 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): one = torch.ones(1, dtype=torch.int64) zero = torch.zeros(1, dtype=torch.int64) - w = (box[2] - box[0] + one) - h = (box[3] - box[1] + one) + w = box[2] - box[0] + one + h = box[3] - box[1] + one w = torch.max(torch.cat((w, one))) h = torch.max(torch.cat((h, one))) @@ -423,7 +435,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, mask.size(0), mask.size(1))) # Resize mask - mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False) mask = mask[0][0] x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) @@ -431,23 +443,18 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero))) y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0)))) - unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]), - (x_0 - box[0]):(x_1 - box[0])] + unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] # TODO : replace below with a dynamic padding when support is added in ONNX # pad y zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1)) zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1)) - concat_0 = torch.cat((zeros_y0, - unpaded_im_mask.to(dtype=torch.float32), - zeros_y1), 0)[0:im_h, :] + concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :] # 
pad x zeros_x0 = torch.zeros(concat_0.size(0), x_0) zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1) - im_mask = torch.cat((zeros_x0, - concat_0, - zeros_x1), 1)[:, :im_w] + im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w] return im_mask @@ -468,13 +475,10 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): im_h, im_w = img_shape if torchvision._is_tracing(): - return _onnx_paste_masks_in_image_loop(masks, boxes, - torch.scalar_tensor(im_h, dtype=torch.int64), - torch.scalar_tensor(im_w, dtype=torch.int64))[:, None] - res = [ - paste_mask_in_image(m[0], b, im_h, im_w) - for m, b in zip(masks, boxes) - ] + return _onnx_paste_masks_in_image_loop( + masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64) + )[:, None] + res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)] if len(res) > 0: ret = torch.stack(res, dim=0)[:, None] else: @@ -484,46 +488,44 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): class RoIHeads(nn.Module): __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, } - def __init__(self, - box_roi_pool, - box_head, - box_predictor, - # Faster R-CNN training - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - bbox_reg_weights, - # Faster R-CNN inference - score_thresh, - nms_thresh, - detections_per_img, - # Mask - mask_roi_pool=None, - mask_head=None, - mask_predictor=None, - keypoint_roi_pool=None, - keypoint_head=None, - keypoint_predictor=None, - ): + def __init__( + self, + box_roi_pool, + box_head, + box_predictor, + # Faster R-CNN training + fg_iou_thresh, + bg_iou_thresh, + batch_size_per_image, + positive_fraction, + bbox_reg_weights, + # Faster R-CNN inference + score_thresh, + nms_thresh, + detections_per_img, + # Mask + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + ): super(RoIHeads, self).__init__() self.box_similarity = box_ops.box_iou # assign ground-truth boxes for each proposal - self.proposal_matcher = det_utils.Matcher( - fg_iou_thresh, - bg_iou_thresh, - allow_low_quality_matches=False) + self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False) - self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( - batch_size_per_image, - positive_fraction) + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) if bbox_reg_weights is None: - bbox_reg_weights = (10., 10., 5., 5.) 
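# --- Editor's illustrative sketch (not part of the diff): what the default
# bbox_reg_weights = (10.0, 10.0, 5.0, 5.0) passed to det_utils.BoxCoder just below
# represent. encode_single_box is a standalone helper written for this sketch, not a
# torchvision API; it spells out the standard Faster R-CNN encoding of one
# ground-truth box against one proposal, both in (x1, y1, x2, y2) format, with the
# four weights scaling the dx, dy, dw, dh regression targets.
import math

def encode_single_box(gt, proposal, weights=(10.0, 10.0, 5.0, 5.0)):
    wx, wy, ww, wh = weights
    px, py = (proposal[0] + proposal[2]) / 2.0, (proposal[1] + proposal[3]) / 2.0
    pw, ph = proposal[2] - proposal[0], proposal[3] - proposal[1]
    gx, gy = (gt[0] + gt[2]) / 2.0, (gt[1] + gt[3]) / 2.0
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    return (
        wx * (gx - px) / pw,       # weighted center-x offset
        wy * (gy - py) / ph,       # weighted center-y offset
        ww * math.log(gw / pw),    # weighted log width ratio
        wh * math.log(gh / ph),    # weighted log height ratio
    )

print(encode_single_box(gt=[0.0, 0.0, 100.0, 100.0], proposal=[10.0, 10.0, 90.0, 90.0]))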
+ bbox_reg_weights = (10.0, 10.0, 5.0, 5.0) self.box_coder = det_utils.BoxCoder(bbox_reg_weights) self.box_roi_pool = box_roi_pool @@ -572,9 +574,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): clamped_matched_idxs_in_image = torch.zeros( (proposals_in_image.shape[0],), dtype=torch.int64, device=device ) - labels_in_image = torch.zeros( - (proposals_in_image.shape[0],), dtype=torch.int64, device=device - ) + labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device) else: # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image) @@ -601,19 +601,14 @@ def subsample(self, labels): # type: (List[Tensor]) -> List[Tensor] sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) sampled_inds = [] - for img_idx, (pos_inds_img, neg_inds_img) in enumerate( - zip(sampled_pos_inds, sampled_neg_inds) - ): + for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)): img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0] sampled_inds.append(img_sampled_inds) return sampled_inds def add_gt_proposals(self, proposals, gt_boxes): # type: (List[Tensor], List[Tensor]) -> List[Tensor] - proposals = [ - torch.cat((proposal, gt_box)) - for proposal, gt_box in zip(proposals, gt_boxes) - ] + proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)] return proposals @@ -625,10 +620,11 @@ def check_targets(self, targets): if self.has_mask(): assert all(["masks" in t for t in targets]) - def select_training_samples(self, - proposals, # type: List[Tensor] - targets # type: Optional[List[Dict[str, Tensor]]] - ): + def select_training_samples( + self, + proposals, # type: List[Tensor] + targets, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]] self.check_targets(targets) assert targets is not None @@ -661,12 +657,13 @@ def select_training_samples(self, regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) return proposals, matched_idxs, labels, regression_targets - def postprocess_detections(self, - class_logits, # type: Tensor - box_regression, # type: Tensor - proposals, # type: List[Tensor] - image_shapes # type: List[Tuple[int, int]] - ): + def postprocess_detections( + self, + class_logits, # type: Tensor + box_regression, # type: Tensor + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + ): # type: (...) 
-> Tuple[List[Tensor], List[Tensor], List[Tensor]] device = class_logits.device num_classes = class_logits.shape[-1] @@ -710,7 +707,7 @@ def postprocess_detections(self, # non-maximum suppression, independently done per class keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) # keep only topk scoring predictions - keep = keep[:self.detections_per_img] + keep = keep[: self.detections_per_img] boxes, scores, labels = boxes[keep], scores[keep], labels[keep] all_boxes.append(boxes) @@ -719,12 +716,13 @@ def postprocess_detections(self, return all_boxes, all_scores, all_labels - def forward(self, - features, # type: Dict[str, Tensor] - proposals, # type: List[Tensor] - image_shapes, # type: List[Tuple[int, int]] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): + def forward( + self, + features, # type: Dict[str, Tensor] + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] """ Args: @@ -737,10 +735,10 @@ def forward(self, for t in targets: # TODO: https://github.com/pytorch/pytorch/issues/26731 floating_point_types = (torch.float, torch.double, torch.half) - assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' - assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' + assert t["boxes"].dtype in floating_point_types, "target boxes must of float type" + assert t["labels"].dtype == torch.int64, "target labels must of int64 type" if self.has_keypoint(): - assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' + assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type" if self.training: proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) @@ -757,12 +755,8 @@ def forward(self, losses = {} if self.training: assert labels is not None and regression_targets is not None - loss_classifier, loss_box_reg = fastrcnn_loss( - class_logits, box_regression, labels, regression_targets) - losses = { - "loss_classifier": loss_classifier, - "loss_box_reg": loss_box_reg - } + loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets) + losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg} else: boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) num_images = len(boxes) @@ -805,12 +799,8 @@ def forward(self, gt_masks = [t["masks"] for t in targets] gt_labels = [t["labels"] for t in targets] - rcnn_loss_mask = maskrcnn_loss( - mask_logits, mask_proposals, - gt_masks, gt_labels, pos_matched_idxs) - loss_mask = { - "loss_mask": rcnn_loss_mask - } + rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs) + loss_mask = {"loss_mask": rcnn_loss_mask} else: labels = [r["labels"] for r in result] masks_probs = maskrcnn_inference(mask_logits, labels) @@ -821,8 +811,11 @@ def forward(self, # keep none checks in if conditional so torchscript will conditionally # compile each branch - if self.keypoint_roi_pool is not None and self.keypoint_head is not None \ - and self.keypoint_predictor is not None: + if ( + self.keypoint_roi_pool is not None + and self.keypoint_head is not None + and self.keypoint_predictor is not None + ): keypoint_proposals = [p["boxes"] for p in result] if self.training: # during training, only focus on positive 
boxes @@ -848,11 +841,9 @@ def forward(self, gt_keypoints = [t["keypoints"] for t in targets] rcnn_loss_keypoint = keypointrcnn_loss( - keypoint_logits, keypoint_proposals, - gt_keypoints, pos_matched_idxs) - loss_keypoint = { - "loss_keypoint": rcnn_loss_keypoint - } + keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs + ) + loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint} else: assert keypoint_logits is not None assert keypoint_proposals is not None diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index a98eac24dd3..7c5fc66f133 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -1,27 +1,25 @@ -import torch -from torch.nn import functional as F -from torch import nn, Tensor +from typing import Dict, List, Optional, Tuple +import torch import torchvision +from torch import Tensor, nn +from torch.nn import functional as F from torchvision.ops import boxes as box_ops from . import _utils as det_utils -from .image_list import ImageList - -from typing import List, Optional, Dict, Tuple # Import AnchorGenerator to keep compatibility. -from .anchor_utils import AnchorGenerator +from .anchor_utils import AnchorGenerator # noqa: F401 +from .image_list import ImageList @torch.jit.unused def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): # type: (Tensor, int) -> Tuple[int, int] from torch.onnx import operators + num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) - pre_nms_top_n = torch.min(torch.cat( - (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), - num_anchors), 0)) + pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0)) return num_anchors, pre_nms_top_n @@ -37,13 +35,9 @@ class RPNHead(nn.Module): def __init__(self, in_channels, num_anchors): super(RPNHead, self).__init__() - self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) - self.bbox_pred = nn.Conv2d( - in_channels, num_anchors * 4, kernel_size=1, stride=1 - ) + self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1) for layer in self.children(): torch.nn.init.normal_(layer.weight, std=0.01) @@ -76,21 +70,15 @@ def concat_box_prediction_layers(box_cls, box_regression): # same format as the labels. 
Note that the labels are computed for # all feature levels concatenated, so we keep the same representation # for the objectness and the box_regression - for box_cls_per_level, box_regression_per_level in zip( - box_cls, box_regression - ): + for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression): N, AxC, H, W = box_cls_per_level.shape Ax4 = box_regression_per_level.shape[1] A = Ax4 // 4 C = AxC // A - box_cls_per_level = permute_and_flatten( - box_cls_per_level, N, A, C, H, W - ) + box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W) box_cls_flattened.append(box_cls_per_level) - box_regression_per_level = permute_and_flatten( - box_regression_per_level, N, A, 4, H, W - ) + box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W) box_regression_flattened.append(box_regression_per_level) # concatenate on the first dimension (representing the feature levels), to # take into account the way the labels were generated (with all feature maps @@ -125,22 +113,30 @@ class RegionProposalNetwork(torch.nn.Module): nms_thresh (float): NMS threshold used for postprocessing the RPN proposals """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, - 'pre_nms_top_n': Dict[str, int], - 'post_nms_top_n': Dict[str, int], + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, + "pre_nms_top_n": Dict[str, int], + "post_nms_top_n": Dict[str, int], } - def __init__(self, - anchor_generator, - head, - # - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - # - pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0): + def __init__( + self, + anchor_generator, + head, + # + fg_iou_thresh, + bg_iou_thresh, + batch_size_per_image, + positive_fraction, + # + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + score_thresh=0.0, + ): super(RegionProposalNetwork, self).__init__() self.anchor_generator = anchor_generator self.head = head @@ -155,9 +151,7 @@ def __init__(self, allow_low_quality_matches=True, ) - self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( - batch_size_per_image, positive_fraction - ) + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) # used during testing self._pre_nms_top_n = pre_nms_top_n self._post_nms_top_n = post_nms_top_n @@ -167,13 +161,13 @@ def __init__(self, def pre_nms_top_n(self): if self.training: - return self._pre_nms_top_n['training'] - return self._pre_nms_top_n['testing'] + return self._pre_nms_top_n["training"] + return self._pre_nms_top_n["testing"] def post_nms_top_n(self): if self.training: - return self._post_nms_top_n['training'] - return self._post_nms_top_n['testing'] + return self._post_nms_top_n["training"] + return self._post_nms_top_n["testing"] def assign_targets_to_anchors(self, anchors, targets): # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] @@ -235,8 +229,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ objectness = objectness.reshape(num_images, -1) levels = [ - torch.full((n,), idx, dtype=torch.int64, device=device) - for idx, n in enumerate(num_anchors_per_level) + torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) ] levels = torch.cat(levels, 0) levels = levels.reshape(1, 
-1).expand_as(objectness) @@ -271,7 +264,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) # keep only topk scoring predictions - keep = keep[:self.post_nms_top_n()] + keep = keep[: self.post_nms_top_n()] boxes, scores = boxes[keep], scores[keep] final_boxes.append(boxes) @@ -303,24 +296,26 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) labels = torch.cat(labels, dim=0) regression_targets = torch.cat(regression_targets, dim=0) - box_loss = F.smooth_l1_loss( - pred_bbox_deltas[sampled_pos_inds], - regression_targets[sampled_pos_inds], - beta=1 / 9, - reduction='sum', - ) / (sampled_inds.numel()) - - objectness_loss = F.binary_cross_entropy_with_logits( - objectness[sampled_inds], labels[sampled_inds] + box_loss = ( + F.smooth_l1_loss( + pred_bbox_deltas[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1 / 9, + reduction="sum", + ) + / (sampled_inds.numel()) ) + objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds]) + return objectness_loss, box_loss - def forward(self, - images, # type: ImageList - features, # type: Dict[str, Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): + def forward( + self, + images, # type: ImageList + features, # type: Dict[str, Tensor] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] """ Args: @@ -346,8 +341,7 @@ def forward(self, num_images = len(anchors) num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] - objectness, pred_bbox_deltas = \ - concat_box_prediction_layers(objectness, pred_bbox_deltas) + objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas) # apply pred_bbox_deltas to anchors to obtain the decoded proposals # note that we detach the deltas because Faster R-CNN do not backprop through # the proposals @@ -361,7 +355,8 @@ def forward(self, labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) loss_objectness, loss_rpn_box_reg = self.compute_loss( - objectness, pred_bbox_deltas, labels, regression_targets) + objectness, pred_bbox_deltas, labels, regression_targets + ) losses = { "loss_objectness": loss_objectness, "loss_rpn_box_reg": loss_rpn_box_reg, diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index e67c4930b30..800561ca25e 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -1,29 +1,29 @@ -import torch -import torch.nn.functional as F import warnings - from collections import OrderedDict -from torch import nn, Tensor from typing import Any, Dict, List, Optional, Tuple +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import boxes as box_ops +from .. import vgg from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -from .. 
import vgg -from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops import boxes as box_ops -__all__ = ['SSD', 'ssd300_vgg16'] +__all__ = ["SSD", "ssd300_vgg16"] model_urls = { - 'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', + "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", } backbone_urls = { # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - 'vgg16_features': 'https://download.pytorch.org/models/vgg16_features-amdegroot.pth' + "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth" } @@ -43,8 +43,8 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: return { - 'bbox_regression': self.regression_head(x), - 'cls_logits': self.classification_head(x), + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), } @@ -159,31 +159,38 @@ class SSD(nn.Module): proposals used during the training of the classification head. It is used to estimate the negative to positive ratio. """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, } - def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, - size: Tuple[int, int], num_classes: int, - image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None, - head: Optional[nn.Module] = None, - score_thresh: float = 0.01, - nms_thresh: float = 0.45, - detections_per_img: int = 200, - iou_thresh: float = 0.5, - topk_candidates: int = 400, - positive_fraction: float = 0.25): + def __init__( + self, + backbone: nn.Module, + anchor_generator: DefaultBoxGenerator, + size: Tuple[int, int], + num_classes: int, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + head: Optional[nn.Module] = None, + score_thresh: float = 0.01, + nms_thresh: float = 0.45, + detections_per_img: int = 200, + iou_thresh: float = 0.5, + topk_candidates: int = 400, + positive_fraction: float = 0.25, + ): super().__init__() self.backbone = backbone self.anchor_generator = anchor_generator - self.box_coder = det_utils.BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) if head is None: - if hasattr(backbone, 'out_channels'): + if hasattr(backbone, "out_channels"): out_channels = backbone.out_channels else: out_channels = det_utils.retrieve_out_channels(backbone, size) @@ -200,8 +207,9 @@ def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, image_mean = [0.485, 0.456, 0.406] if image_std is None: image_std = [0.229, 0.224, 0.225] - self.transform = GeneralizedRCNNTransform(min(size), max(size), image_mean, image_std, - size_divisible=1, fixed_size=size) + self.transform = GeneralizedRCNNTransform( + min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size + ) self.score_thresh = score_thresh self.nms_thresh = nms_thresh @@ -213,45 +221,58 @@ def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, self._has_warned = False @torch.jit.unused - def eager_outputs(self, losses: Dict[str, Tensor], - detections: List[Dict[str, Tensor]]) -> 
Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + def eager_outputs( + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training: return losses return detections - def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor], - matched_idxs: List[Tensor]) -> Dict[str, Tensor]: - bbox_regression = head_outputs['bbox_regression'] - cls_logits = head_outputs['cls_logits'] + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: + bbox_regression = head_outputs["bbox_regression"] + cls_logits = head_outputs["cls_logits"] # Match original targets with default boxes num_foreground = 0 bbox_loss = [] cls_targets = [] - for (targets_per_image, bbox_regression_per_image, cls_logits_per_image, anchors_per_image, - matched_idxs_per_image) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): + for ( + targets_per_image, + bbox_regression_per_image, + cls_logits_per_image, + anchors_per_image, + matched_idxs_per_image, + ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): # produce the matching between boxes and targets foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image] num_foreground += foreground_matched_idxs_per_image.numel() # Calculate regression loss - matched_gt_boxes_per_image = targets_per_image['boxes'][foreground_matched_idxs_per_image] + matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image] bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) - bbox_loss.append(torch.nn.functional.smooth_l1_loss( - bbox_regression_per_image, - target_regression, - reduction='sum' - )) + bbox_loss.append( + torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + ) # Estimate ground truth for class targets - gt_classes_target = torch.zeros((cls_logits_per_image.size(0), ), dtype=targets_per_image['labels'].dtype, - device=targets_per_image['labels'].device) - gt_classes_target[foreground_idxs_per_image] = \ - targets_per_image['labels'][foreground_matched_idxs_per_image] + gt_classes_target = torch.zeros( + (cls_logits_per_image.size(0),), + dtype=targets_per_image["labels"].dtype, + device=targets_per_image["labels"].device, + ) + gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][ + foreground_matched_idxs_per_image + ] cls_targets.append(gt_classes_target) bbox_loss = torch.stack(bbox_loss) @@ -259,30 +280,29 @@ def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, # Calculate classification loss num_classes = cls_logits.size(-1) - cls_loss = F.cross_entropy( - cls_logits.view(-1, num_classes), - cls_targets.view(-1), - reduction='none' - ).view(cls_targets.size()) + cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view( + cls_targets.size() + ) # Hard Negative Sampling foreground_idxs = cls_targets > 0 num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True) # num_negative[num_negative < self.neg_to_pos_ratio] = 
self.neg_to_pos_ratio negative_loss = cls_loss.clone() - negative_loss[foreground_idxs] = -float('inf') # use -inf to detect positive values that creeped in the sample + negative_loss[foreground_idxs] = -float("inf") # use -inf to detect positive values that creeped in the sample values, idx = negative_loss.sort(1, descending=True) # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values)) background_idxs = idx.sort(1)[1] < num_negative N = max(1, num_foreground) return { - 'bbox_regression': bbox_loss.sum() / N, - 'classification': (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, + "bbox_regression": bbox_loss.sum() / N, + "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, } - def forward(self, images: List[Tensor], - targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + def forward( + self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training and targets is None: raise ValueError("In training mode, targets should be passed") @@ -292,12 +312,11 @@ def forward(self, images: List[Tensor], boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] @@ -317,14 +336,15 @@ def forward(self, images: List[Tensor], if degenerate_boxes.any(): bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." 
+ " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) # get the features from the backbone features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) features = list(features.values()) @@ -341,12 +361,13 @@ def forward(self, images: List[Tensor], matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): - if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, - device=anchors_per_image.device)) + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) continue - match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) matched_idxs.append(self.proposal_matcher(match_quality_matrix)) losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs) @@ -361,10 +382,11 @@ def forward(self, images: List[Tensor], return losses, detections return self.eager_outputs(losses, detections) - def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], - image_shapes: List[Tuple[int, int]]) -> List[Dict[str, Tensor]]: - bbox_regression = head_outputs['bbox_regression'] - pred_scores = F.softmax(head_outputs['cls_logits'], dim=-1) + def postprocess_detections( + self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: + bbox_regression = head_outputs["bbox_regression"] + pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1) num_classes = pred_scores.size(-1) device = pred_scores.device @@ -400,13 +422,15 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: # non-maximum suppression keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) - keep = keep[:self.detections_per_img] - - detections.append({ - 'boxes': image_boxes[keep], - 'scores': image_scores[keep], - 'labels': image_labels[keep], - }) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) return detections @@ -423,45 +447,47 @@ def __init__(self, backbone: nn.Module, highres: bool): self.scale_weight = nn.Parameter(torch.ones(512) * 20) # Multiple Feature maps - page 4, Fig 2 of SSD paper - self.features = nn.Sequential( - *backbone[:maxpool4_pos] # until conv4_3 - ) + self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3 # SSD300 case - page 4, Fig 2 of SSD paper - extra = nn.ModuleList([ - nn.Sequential( - nn.Conv2d(1024, 256, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(512, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3), # conv10_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3), # conv11_2 - nn.ReLU(inplace=True), - ) - ]) + extra = 
nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(1024, 256, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv10_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv11_2 + nn.ReLU(inplace=True), + ), + ] + ) if highres: # Additional layers for the SSD512 case. See page 11, footernote 5. - extra.append(nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=4), # conv12_2 - nn.ReLU(inplace=True), - )) + extra.append( + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=4), # conv12_2 + nn.ReLU(inplace=True), + ) + ) _xavier_init(extra) fc = nn.Sequential( @@ -469,13 +495,16 @@ def __init__(self, backbone: nn.Module, highres: bool): nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous nn.ReLU(inplace=True), nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7 - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) _xavier_init(fc) - extra.insert(0, nn.Sequential( - *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 - fc, - )) + extra.insert( + 0, + nn.Sequential( + *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 + fc, + ), + ) self.extra = extra def forward(self, x: Tensor) -> Dict[str, Tensor]: @@ -495,7 +524,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): if backbone_name in backbone_urls: # Use custom backbones more appropriate for SSD - arch = backbone_name.split('_')[0] + arch = backbone_name.split("_")[0] backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features if pretrained: state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress) @@ -519,8 +548,14 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained return SSDFeatureExtractorVGG(backbone, highres) -def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, - pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): +def ssd300_vgg16( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +): """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. 
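For readers following the reformatted `ssd300_vgg16` entry point above, a minimal usage sketch may help; this is an illustration only, assuming a torchvision build that ships this constructor with the parameters shown in the patch (`pretrained`, `pretrained_backbone`, `num_classes`), and it is not part of the diff itself:

    import torch
    from torchvision.models.detection import ssd300_vgg16

    # Build the detector without downloading any checkpoints
    # (both the COCO weights and the VGG16 backbone weights are skipped).
    model = ssd300_vgg16(pretrained=False, pretrained_backbone=False, num_classes=91)
    model.eval()

    # In eval mode the model returns one dict per image with "boxes", "scores" and "labels",
    # matching the keys produced by postprocess_detections in the hunk above.
    images = [torch.rand(3, 300, 300)]
    with torch.no_grad():
        detections = model(images)
    print(detections[0]["boxes"].shape)
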
@@ -569,16 +604,19 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i warnings.warn("The size of the model is already fixed; ignoring the argument.") trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], - scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], - steps=[8, 16, 32, 64, 100, 300]) + anchor_generator = DefaultBoxGenerator( + [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], + steps=[8, 16, 32, 64, 100, 300], + ) defaults = { # Rescale the input in a way compatible to the backbone @@ -588,7 +626,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i kwargs = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) if pretrained: - weights_name = 'ssd300_vgg16_coco' + weights_name = "ssd300_vgg16_coco" if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 08d48c68020..b87fa2e2a2f 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -1,38 +1,42 @@ -import torch import warnings - from collections import OrderedDict from functools import partial -from torch import nn, Tensor -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional +import torch +from torch import Tensor, nn + +from ..._internally_replaced_utils import load_state_dict_from_url +from .. import mobilenet +from ..mobilenetv3 import ConvBNActivation from . import _utils as det_utils -from .ssd import SSD, SSDScoringHead from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers -from .. 
import mobilenet -from ..mobilenetv3 import ConvBNActivation -from ..._internally_replaced_utils import load_state_dict_from_url - +from .ssd import SSD, SSDScoringHead -__all__ = ['ssdlite320_mobilenet_v3_large'] +__all__ = ["ssdlite320_mobilenet_v3_large"] model_urls = { - 'ssdlite320_mobilenet_v3_large_coco': - 'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth' + "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" # noqa: E501 } # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper -def _prediction_block(in_channels: int, out_channels: int, kernel_size: int, - norm_layer: Callable[..., nn.Module]) -> nn.Sequential: +def _prediction_block( + in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module] +) -> nn.Sequential: return nn.Sequential( # 3x3 depthwise with stride 1 and padding 1 - ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, - norm_layer=norm_layer, activation_layer=nn.ReLU6), - + ConvBNActivation( + in_channels, + in_channels, + kernel_size=kernel_size, + groups=in_channels, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ), # 1x1 projetion to output channels - nn.Conv2d(in_channels, out_channels, 1) + nn.Conv2d(in_channels, out_channels, 1), ) @@ -41,16 +45,23 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., intermediate_channels = out_channels // 2 return nn.Sequential( # 1x1 projection to half output channels - ConvBNActivation(in_channels, intermediate_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation), - + ConvBNActivation( + in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), # 3x3 depthwise with stride 2 and padding 1 - ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2, - groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation), - + ConvBNActivation( + intermediate_channels, + intermediate_channels, + kernel_size=3, + stride=2, + groups=intermediate_channels, + norm_layer=norm_layer, + activation_layer=activation, + ), # 1x1 projetion to output channels - ConvBNActivation(intermediate_channels, out_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation), + ConvBNActivation( + intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), ) @@ -63,22 +74,24 @@ def _normal_init(conv: nn.Module): class SSDLiteHead(nn.Module): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, - norm_layer: Callable[..., nn.Module]): + def __init__( + self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): super().__init__() self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer) self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer) def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: return { - 'bbox_regression': self.regression_head(x), - 'cls_logits': self.classification_head(x), + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), } class SSDLiteClassificationHead(SSDScoringHead): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, - norm_layer: Callable[..., nn.Module]): + def __init__( + 
self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): cls_logits = nn.ModuleList() for channels, anchors in zip(in_channels, num_anchors): cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer)) @@ -96,24 +109,33 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C class SSDLiteFeatureExtractorMobileNet(nn.Module): - def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], width_mult: float = 1.0, - min_depth: int = 16, **kwargs: Any): + def __init__( + self, + backbone: nn.Module, + c4_pos: int, + norm_layer: Callable[..., nn.Module], + width_mult: float = 1.0, + min_depth: int = 16, + **kwargs: Any, + ): super().__init__() assert not backbone[c4_pos].use_res_connect self.features = nn.Sequential( # As described in section 6.3 of MobileNetV3 paper nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer - nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end + nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]), # from C4 depthwise until end ) get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731 - extra = nn.ModuleList([ - _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), - _extra_block(get_depth(512), get_depth(256), norm_layer), - _extra_block(get_depth(256), get_depth(256), norm_layer), - _extra_block(get_depth(256), get_depth(128), norm_layer), - ]) + extra = nn.ModuleList( + [ + _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), + _extra_block(get_depth(512), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(128), norm_layer), + ] + ) _normal_init(extra) self.extra = extra @@ -132,10 +154,17 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, - norm_layer: Callable[..., nn.Module], **kwargs: Any): - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, - norm_layer=norm_layer, **kwargs).features +def _mobilenet_extractor( + backbone_name: str, + progress: bool, + pretrained: bool, + trainable_layers: int, + norm_layer: Callable[..., nn.Module], + **kwargs: Any, +): + backbone = mobilenet.__dict__[backbone_name]( + pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs + ).features if not pretrained: # Change the default initialization scheme if not pretrained _normal_init(backbone) @@ -156,10 +185,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) -def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, - pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any): +def ssdlite320_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = False, + trainable_backbone_layers: Optional[int] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, +): """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 
Large backbone, as described at `"Searching for MobileNetV3" `_ and @@ -188,7 +222,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru warnings.warn("The size of the model is already fixed; ignoring the argument.") trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6 + ) if pretrained: pretrained_backbone = False @@ -199,8 +234,15 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, - norm_layer, reduced_tail=reduce_tail, **kwargs) + backbone = _mobilenet_extractor( + "mobilenet_v3_large", + progress, + pretrained_backbone, + trainable_backbone_layers, + norm_layer, + reduced_tail=reduce_tail, + **kwargs, + ) size = (320, 320) anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) @@ -219,11 +261,17 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru "image_std": [0.5, 0.5, 0.5], } kwargs = {**defaults, **kwargs} - model = SSD(backbone, anchor_generator, size, num_classes, - head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs) + model = SSD( + backbone, + anchor_generator, + size, + num_classes, + head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), + **kwargs, + ) if pretrained: - weights_name = 'ssdlite320_mobilenet_v3_large_coco' + weights_name = "ssdlite320_mobilenet_v3_large_coco" if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 0ca5273e047..860d9d76cd9 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,9 +1,9 @@ import math +from typing import Dict, List, Optional, Tuple + import torch import torchvision - -from torch import nn, Tensor -from typing import List, Tuple, Dict, Optional +from torch import Tensor, nn from .image_list import ImageList from .roi_heads import paste_masks_in_image @@ -13,6 +13,7 @@ def _get_shape_onnx(image): # type: (Tensor) -> Tensor from torch.onnx import operators + return operators.shape_as_tensor(image)[-2:] @@ -23,10 +24,13 @@ def _fake_cast_onnx(v): return v -def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: float, - target: Optional[Dict[str, Tensor]] = None, - fixed_size: Optional[Tuple[int, int]] = None, - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: +def _resize_image_and_masks( + image: Tensor, + self_min_size: float, + self_max_size: float, + target: Optional[Dict[str, Tensor]] = None, + fixed_size: Optional[Tuple[int, int]] = None, +) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if torchvision._is_tracing(): im_shape = _get_shape_onnx(image) else: @@ -48,16 +52,23 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: scale_factor = scale.item() recompute_scale_factor = True - image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear', - recompute_scale_factor=recompute_scale_factor, 
align_corners=False)[0] + image = torch.nn.functional.interpolate( + image[None], + size=size, + scale_factor=scale_factor, + mode="bilinear", + recompute_scale_factor=recompute_scale_factor, + align_corners=False, + )[0] if target is None: return image, target if "masks" in target: mask = target["masks"] - mask = torch.nn.functional.interpolate(mask[:, None].float(), size=size, scale_factor=scale_factor, - recompute_scale_factor=recompute_scale_factor)[:, 0].byte() + mask = torch.nn.functional.interpolate( + mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor + )[:, 0].byte() target["masks"] = mask return image, target @@ -85,10 +96,11 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, self.size_divisible = size_divisible self.fixed_size = fixed_size - def forward(self, - images, # type: List[Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): + def forward( + self, + images, # type: List[Tensor] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]] images = [img for img in images] if targets is not None: @@ -108,8 +120,9 @@ def forward(self, target_index = targets[i] if targets is not None else None if image.dim() != 3: - raise ValueError("images is expected to be a list of 3d tensors " - "of shape [C, H, W], got {}".format(image.shape)) + raise ValueError( + "images is expected to be a list of 3d tensors " "of shape [C, H, W], got {}".format(image.shape) + ) image = self.normalize(image) image, target_index = self.resize(image, target_index) images[i] = image @@ -144,13 +157,14 @@ def torch_choice(self, k): TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 is fixed. """ - index = int(torch.empty(1).uniform_(0., float(len(k))).item()) + index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) return k[index] - def resize(self, - image: Tensor, - target: Optional[Dict[str, Tensor]] = None, - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def resize( + self, + image: Tensor, + target: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: h, w = image.shape[-2:] if self.training: size = float(self.torch_choice(self.min_size)) @@ -225,11 +239,12 @@ def batch_images(self, images, size_divisible=32): return batched_imgs - def postprocess(self, - result, # type: List[Dict[str, Tensor]] - image_shapes, # type: List[Tuple[int, int]] - original_image_sizes # type: List[Tuple[int, int]] - ): + def postprocess( + self, + result, # type: List[Dict[str, Tensor]] + image_shapes, # type: List[Tuple[int, int]] + original_image_sizes, # type: List[Tuple[int, int]] + ): # type: (...) 
-> List[Dict[str, Tensor]] if self.training: return result @@ -248,20 +263,21 @@ def postprocess(self, return result def __repr__(self): - format_string = self.__class__.__name__ + '(' - _indent = '\n ' + format_string = self.__class__.__name__ + "(" + _indent = "\n " format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) - format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size, - self.max_size) - format_string += '\n)' + format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format( + _indent, self.min_size, self.max_size + ) + format_string += "\n)" return format_string def resize_keypoints(keypoints, original_size, new_size): # type: (Tensor, List[int], List[int]) -> Tensor ratios = [ - torch.tensor(s, dtype=torch.float32, device=keypoints.device) / - torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) + torch.tensor(s, dtype=torch.float32, device=keypoints.device) + / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) for s, s_orig in zip(new_size, original_size) ] ratio_h, ratio_w = ratios @@ -279,8 +295,8 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) -> Tensor ratios = [ - torch.tensor(s, dtype=torch.float32, device=boxes.device) / - torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) + torch.tensor(s, dtype=torch.float32, device=boxes.device) + / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) for s, s_orig in zip(new_size, original_size) ] ratio_height, ratio_width = ratios diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 0745ef4eef6..32919c95c36 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,22 +1,23 @@ import warnings from collections import namedtuple +from typing import Any, Callable, List, Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Optional, Tuple, List, Callable, Any -__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"] +__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"] model_urls = { # GoogLeNet ported from TensorFlow - 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", } -GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) -GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor], - 'aux_logits1': Optional[Tensor]} +GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) +GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]} # Script annotations failed with _GoogleNetOutputs = namedtuple ... # _GoogLeNetOutputs set here for backwards compat @@ -37,19 +38,19 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> was trained on ImageNet. 
Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' not in kwargs: - kwargs['aux_logits'] = False - if kwargs['aux_logits']: - warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' - 'so make sure to train them') - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = False + if kwargs["aux_logits"]: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them" + ) + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True + kwargs["init_weights"] = False model = GoogLeNet(**kwargs) - state_dict = load_state_dict_from_url(model_urls['googlenet'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress) model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False @@ -61,7 +62,7 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> class GoogLeNet(nn.Module): - __constants__ = ['aux_logits', 'transform_input'] + __constants__ = ["aux_logits", "transform_input"] def __init__( self, @@ -69,15 +70,18 @@ def __init__( aux_logits: bool = True, transform_input: bool = False, init_weights: Optional[bool] = None, - blocks: Optional[List[Callable[..., nn.Module]]] = None + blocks: Optional[List[Callable[..., nn.Module]]] = None, ) -> None: super(GoogLeNet, self).__init__() if blocks is None: blocks = [BasicConv2d, Inception, InceptionAux] if init_weights is None: - warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of ' - 'torchvision. If you wish to keep the old behavior (which leads to long initialization times' - ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) + warnings.warn( + "The default weight initialization of GoogleNet will be changed in future releases of " + "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" + " due to scipy/scipy#11299), please set init_weights=True.", + FutureWarning, + ) init_weights = True assert len(blocks) == 3 conv_block = blocks[0] @@ -125,6 +129,7 @@ def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): import scipy.stats as stats + X = stats.truncnorm(-2, 2, scale=0.01) values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) values = values.view(m.weight.size()) @@ -202,7 +207,7 @@ def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> Goog if self.training and self.aux_logits: return _GoogLeNetOutputs(x, aux2, aux1) else: - return x # type: ignore[return-value] + return x # type: ignore[return-value] def forward(self, x: Tensor) -> GoogLeNetOutputs: x = self._transform_input(x) @@ -217,7 +222,6 @@ def forward(self, x: Tensor) -> GoogLeNetOutputs: class Inception(nn.Module): - def __init__( self, in_channels: int, @@ -227,7 +231,7 @@ def __init__( ch5x5red: int, ch5x5: int, pool_proj: int, - conv_block: Optional[Callable[..., nn.Module]] = None + conv_block: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Inception, self).__init__() if conv_block is None: @@ -235,20 +239,19 @@ def __init__( self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) self.branch2 = nn.Sequential( - conv_block(in_channels, ch3x3red, kernel_size=1), - conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1) + conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1) ) self.branch3 = nn.Sequential( conv_block(in_channels, ch5x5red, kernel_size=1), # Here, kernel_size=3 instead of kernel_size=5 is a known bug. # Please see https://github.com/pytorch/vision/issues/906 for details. 
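The comment above points at a long-standing quirk: the branch labelled 5x5 is actually built with a 3x3 kernel. A small standalone check, using made-up channel sizes purely for illustration, makes the parameter-count difference between the implemented branch and the one described in the paper explicit:

    import torch.nn as nn

    # Branch as implemented in the code above (3x3 kernel, padding 1)
    # versus the 5x5 kernel the GoogLeNet paper describes (padding 2).
    as_implemented = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
    as_in_paper = nn.Conv2d(16, 32, kernel_size=5, padding=2, bias=False)

    print(sum(p.numel() for p in as_implemented.parameters()))  # 16 * 32 * 3 * 3 = 4608
    print(sum(p.numel() for p in as_in_paper.parameters()))     # 16 * 32 * 5 * 5 = 12800
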
- conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1) + conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1), ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), - conv_block(in_channels, pool_proj, kernel_size=1) + conv_block(in_channels, pool_proj, kernel_size=1), ) def _forward(self, x: Tensor) -> List[Tensor]: @@ -266,12 +269,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionAux(nn.Module): - def __init__( - self, - in_channels: int, - num_classes: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionAux, self).__init__() if conv_block is None: @@ -300,13 +299,7 @@ def forward(self, x: Tensor) -> Tensor: class BasicConv2d(nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - **kwargs: Any - ) -> None: + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index b9c6ab74534..2bc5335a9e0 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,22 +1,23 @@ -from collections import namedtuple import warnings +from collections import namedtuple +from typing import Any, Callable, List, Optional, Tuple + import torch -from torch import nn, Tensor import torch.nn.functional as F -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Callable, Any, Optional, Tuple, List +from torch import Tensor, nn +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs'] +__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"] model_urls = { # Inception v3 ported from TensorFlow - 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth', + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", } -InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) -InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]} +InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) +InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]} # Script annotations failed with _GoogleNetOutputs = namedtuple ... # _InceptionOutputs set here for backwards compat @@ -41,17 +42,16 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) was trained on ImageNet. 
Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' in kwargs: - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" in kwargs: + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True else: original_aux_logits = True - kwargs['init_weights'] = False # we are loading weights from a pretrained model + kwargs["init_weights"] = False # we are loading weights from a pretrained model model = Inception3(**kwargs) - state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress) model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False @@ -62,25 +62,24 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) class Inception3(nn.Module): - def __init__( self, num_classes: int = 1000, aux_logits: bool = True, transform_input: bool = False, inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, - init_weights: Optional[bool] = None + init_weights: Optional[bool] = None, ) -> None: super(Inception3, self).__init__() if inception_blocks is None: - inception_blocks = [ - BasicConv2d, InceptionA, InceptionB, InceptionC, - InceptionD, InceptionE, InceptionAux - ] + inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] if init_weights is None: - warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of ' - 'torchvision. If you wish to keep the old behavior (which leads to long initialization times' - ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) + warnings.warn( + "The default weight initialization of inception_v3 will be changed in future releases of " + "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" + " due to scipy/scipy#11299), please set init_weights=True.", + FutureWarning, + ) init_weights = True assert len(inception_blocks) == 7 conv_block = inception_blocks[0] @@ -121,7 +120,8 @@ def __init__( for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): import scipy.stats as stats - stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + + stddev = m.stddev if hasattr(m, "stddev") else 0.1 X = stats.truncnorm(-2, 2, scale=stddev) values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) values = values.view(m.weight.size()) @@ -213,12 +213,8 @@ def forward(self, x: Tensor) -> InceptionOutputs: class InceptionA(nn.Module): - def __init__( - self, - in_channels: int, - pool_features: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionA, self).__init__() if conv_block is None: @@ -256,12 +252,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionB(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionB, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -289,12 +280,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionC(nn.Module): - def __init__( - self, - in_channels: int, - channels_7x7: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionC, self).__init__() if conv_block is None: @@ -339,12 +326,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionD(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionD, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -375,12 +357,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionE(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionE, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -427,12 +404,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionAux(nn.Module): - def __init__( - self, - in_channels: int, - num_classes: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionAux, self).__init__() if conv_block is None: @@ -462,13 +435,7 @@ def forward(self, x: Tensor) -> Tensor: class BasicConv2d(nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - **kwargs: Any - ) -> None: + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index ffefab77628..3f48f82c41e 100644 --- a/torchvision/models/mnasnet.py 
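The import churn at the top of this file and the ones that follow is mechanical: isort regroups imports into standard-library, third-party, and package-local sections, alphabetized within each section, and black reflows everything else. A representative target layout, assembled from import lines that appear verbatim in the inception.py hunk above, looks like this:

# isort's grouping: stdlib imports first, then third-party, then package-local,
# separated by blank lines and sorted within each group.
import warnings
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .._internally_replaced_utils import load_state_dict_from_url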
+++ b/torchvision/models/mnasnet.py @@ -1,20 +1,19 @@ import warnings +from typing import Any, Dict, List import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any, Dict, List -__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] +__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"] _MODEL_URLS = { - "mnasnet0_5": - "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", "mnasnet0_75": None, - "mnasnet1_0": - "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - "mnasnet1_3": None + "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + "mnasnet1_3": None, } # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is @@ -23,34 +22,27 @@ class _InvertedResidual(nn.Module): - def __init__( - self, - in_ch: int, - out_ch: int, - kernel_size: int, - stride: int, - expansion_factor: int, - bn_momentum: float = 0.1 + self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1 ) -> None: super(_InvertedResidual, self).__init__() assert stride in [1, 2] assert kernel_size in [3, 5] mid_ch = in_ch * expansion_factor - self.apply_residual = (in_ch == out_ch and stride == 1) + self.apply_residual = in_ch == out_ch and stride == 1 self.layers = nn.Sequential( # Pointwise nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Depthwise - nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, - stride=stride, groups=mid_ch, bias=False), + nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Linear pointwise. Note that there's no activation. nn.Conv2d(mid_ch, out_ch, 1, bias=False), - nn.BatchNorm2d(out_ch, momentum=bn_momentum)) + nn.BatchNorm2d(out_ch, momentum=bn_momentum), + ) def forward(self, input: Tensor) -> Tensor: if self.apply_residual: @@ -59,39 +51,37 @@ def forward(self, input: Tensor) -> Tensor: return self.layers(input) -def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, - bn_momentum: float) -> nn.Sequential: - """ Creates a stack of inverted residuals. """ +def _stack( + in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float +) -> nn.Sequential: + """Creates a stack of inverted residuals.""" assert repeats >= 1 # First one has no skip, because feature map size changes. - first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, - bn_momentum=bn_momentum) + first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum) remaining = [] for _ in range(1, repeats): - remaining.append( - _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, - bn_momentum=bn_momentum)) + remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)) return nn.Sequential(first, *remaining) def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: - """ Asymmetric rounding to make `val` divisible by `divisor`. With default + """Asymmetric rounding to make `val` divisible by `divisor`. 
With default bias, will round up, unless the number is no more than 10% greater than the - smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ + smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88.""" assert 0.0 < round_up_bias < 1.0 new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) return new_val if new_val >= round_up_bias * val else new_val + divisor def _get_depths(alpha: float) -> List[int]: - """ Scales tensor depths as in reference MobileNet code, prefers rouding up - rather than down. """ + """Scales tensor depths as in reference MobileNet code, prefers rouding up + rather than down.""" depths = [32, 16, 24, 40, 80, 96, 192, 320] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] class MNASNet(torch.nn.Module): - """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This + """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This implements the B1 variant of the model. >>> model = MNASNet(1.0, num_classes=1000) >>> x = torch.rand(1, 3, 224, 224) @@ -101,15 +91,11 @@ class MNASNet(torch.nn.Module): >>> y.nelement() 1000 """ + # Version 2 adds depth scaling in the initial stages of the network. _version = 2 - def __init__( - self, - alpha: float, - num_classes: int = 1000, - dropout: float = 0.2 - ) -> None: + def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None: super(MNASNet, self).__init__() assert alpha > 0.0 self.alpha = alpha @@ -121,8 +107,7 @@ def __init__( nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), # Depthwise separable, no skip. - nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, - groups=depths[0], bias=False), + nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), @@ -140,8 +125,7 @@ def __init__( nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), - nn.Linear(1280, num_classes)) + self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes)) self._initialize_weights() def forward(self, x: Tensor) -> Tensor: @@ -153,20 +137,26 @@ def forward(self, x: Tensor) -> Tensor: def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", - nonlinearity="relu") + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode="fan_out", - nonlinearity="sigmoid") + nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid") nn.init.zeros_(m.bias) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool, - missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None: + def _load_from_state_dict( + self, + state_dict: Dict, + prefix: str, + local_metadata: Dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: version = local_metadata.get("version", None) assert version in [1, 2] @@ -180,8 +170,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: D nn.Conv2d(3, 32, 3, padding=1, 
stride=2, bias=False), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), - nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, - bias=False), + nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), @@ -199,20 +188,19 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: D "This checkpoint will load and work as before, but " "you may want to upgrade by training a newer model or " "transfer learning from an updated ImageNet checkpoint.", - UserWarning) + UserWarning, + ) super(MNASNet, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: - raise ValueError( - "No checkpoint is available for model type {}".format(model_name)) + raise ValueError("No checkpoint is available for model type {}".format(model_name)) checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict( - load_state_dict_from_url(checkpoint_url, progress=progress)) + model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 4108305d3f5..c5e897e2d2b 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,4 +1,2 @@ -from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all - -__all__ = mv2_all + mv3_all +from .mobilenetv2 import * +from .mobilenetv3 import * diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 0cfa4f371e3..cac2e06106d 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,15 +1,15 @@ +from typing import Any, Callable, List, Optional + import torch -from torch import nn -from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Callable, Any, Optional, List +from torch import Tensor, nn +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['MobileNetV2', 'mobilenet_v2'] +__all__ = ["MobileNetV2", "mobilenet_v2"] model_urls = { - 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", } @@ -47,10 +47,11 @@ def __init__( if activation_layer is None: activation_layer = nn.ReLU6 super().__init__( - nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, - bias=False), + nn.Conv2d( + in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False + ), norm_layer(out_planes), - activation_layer(inplace=True) + activation_layer(inplace=True), ) self.out_channels = out_planes @@ -61,12 +62,7 @@ def __init__( class InvertedResidual(nn.Module): def __init__( - self, - inp: int, - oup: int, - stride: int, - expand_ratio: int, - norm_layer: Optional[Callable[..., nn.Module]] = None + self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: 
Optional[Callable[..., nn.Module]] = None ) -> None: super(InvertedResidual, self).__init__() self.stride = stride @@ -82,13 +78,15 @@ def __init__( if expand_ratio != 1: # pw layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) - layers.extend([ - # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - norm_layer(oup), - ]) + layers.extend( + [ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ] + ) self.conv = nn.Sequential(*layers) self.out_channels = oup self._is_cn = stride > 1 @@ -108,7 +106,7 @@ def __init__( inverted_residual_setting: Optional[List[List[int]]] = None, round_nearest: int = 8, block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: """ MobileNet V2 main class @@ -148,8 +146,10 @@ def __init__( # only check the first element, assuming user knows t,c,n,s are required if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: - raise ValueError("inverted_residual_setting should be non-empty " - "or a 4-element list, got {}".format(inverted_residual_setting)) + raise ValueError( + "inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting) + ) # building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest) @@ -176,7 +176,7 @@ def __init__( # weight initialization for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): @@ -211,7 +211,6 @@ def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) """ model = MobileNetV2(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ebe3f510a49..6da38cd176e 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,13 +1,12 @@ -import torch - from functools import partial -from torch import nn, Tensor +from typing import Any, Callable, List, Optional, Sequence + +import torch +from torch import Tensor, nn from torch.nn import functional as F -from typing import Any, Callable, Dict, List, Optional, Sequence +from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible from .._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation - __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] @@ -41,8 +40,18 @@ def forward(self, input: Tensor) -> Tensor: class InvertedResidualConfig: # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper - def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, - activation: str, stride: int, dilation: int, width_mult: float): + def __init__( + self, + input_channels: int, + kernel: int, + 
expanded_channels: int, + out_channels: int, + use_se: bool, + activation: str, + stride: int, + dilation: int, + width_mult: float, + ): self.input_channels = self.adjust_channels(input_channels, width_mult) self.kernel = kernel self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) @@ -59,11 +68,15 @@ def adjust_channels(channels: int, width_mult: float): class InvertedResidual(nn.Module): # Implemented as described at section 5 of MobileNetV3 paper - def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = SqueezeExcitation): + def __init__( + self, + cnf: InvertedResidualConfig, + norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation, + ): super().__init__() if not (1 <= cnf.stride <= 2): - raise ValueError('illegal stride value') + raise ValueError("illegal stride value") self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels @@ -72,20 +85,43 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod # expand if cnf.expanded_channels != cnf.input_channels: - layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvBNActivation( + cnf.input_channels, + cnf.expanded_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, - stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvBNActivation( + cnf.expanded_channels, + cnf.expanded_channels, + kernel_size=cnf.kernel, + stride=stride, + dilation=cnf.dilation, + groups=cnf.expanded_channels, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) if cnf.use_se: layers.append(se_layer(cnf.expanded_channels)) # project - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, - activation_layer=nn.Identity)) + layers.append( + ConvBNActivation( + cnf.expanded_channels, + cnf.out_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Identity, + ) + ) self.block = nn.Sequential(*layers) self.out_channels = cnf.out_channels @@ -99,15 +135,14 @@ def forward(self, input: Tensor) -> Tensor: class MobileNetV3(nn.Module): - def __init__( - self, - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any + self, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, ) -> None: """ MobileNet V3 main class @@ -123,8 +158,10 @@ def __init__( if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") - elif not (isinstance(inverted_residual_setting, Sequence) and - all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): + elif not ( + isinstance(inverted_residual_setting, Sequence) + and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting]) + 
): raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") if block is None: @@ -137,8 +174,16 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, - activation_layer=nn.Hardswish)) + layers.append( + ConvBNActivation( + 3, + firstconv_output_channels, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) # building inverted residual blocks for cnf in inverted_residual_setting: @@ -147,8 +192,15 @@ def __init__( # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels - layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=nn.Hardswish)) + layers.append( + ConvBNActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -161,7 +213,7 @@ def __init__( for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): @@ -185,8 +237,9 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, - **kwargs: Any): +def _mobilenet_v3_conf( + arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any +): reduce_divider = 2 if reduced_tail else 1 dilation = 2 if dilated else 1 @@ -239,7 +292,7 @@ def _mobilenet_v3_model( last_channel: int, pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ): model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) if pretrained: diff --git a/torchvision/models/quantization/__init__.py b/torchvision/models/quantization/__init__.py index deae997a219..da8bbba3567 100644 --- a/torchvision/models/quantization/__init__.py +++ b/torchvision/models/quantization/__init__.py @@ -1,5 +1,5 @@ -from .mobilenet import * -from .resnet import * from .googlenet import * from .inception import * +from .mobilenet import * +from .resnet import * from .shufflenetv2 import * diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index bc1477d8f65..cea3b853b04 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -1,20 +1,18 @@ import warnings + import torch import torch.nn as nn from torch.nn import functional as F +from torchvision.models.googlenet import BasicConv2d, GoogLeNet, GoogLeNetOutputs, Inception, InceptionAux, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.googlenet import ( - GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls) - from .utils import _replace_relu, quantize_model - -__all__ = ['QuantizableGoogLeNet', 'googlenet'] +__all__ = ["QuantizableGoogLeNet", "googlenet"] quant_model_urls = { # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch - 'googlenet_fbgemm': 
'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth', + "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", } @@ -36,35 +34,35 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs): was trained on ImageNet. Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' not in kwargs: - kwargs['aux_logits'] = False - if kwargs['aux_logits']: - warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' - 'so make sure to train them') - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = False + if kwargs["aux_logits"]: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them" + ) + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True + kwargs["init_weights"] = False model = QuantizableGoogLeNet(**kwargs) _replace_relu(model) if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls['googlenet' + '_' + backend] + model_url = quant_model_urls["googlenet" + "_" + backend] else: - model_url = model_urls['googlenet'] + model_url = model_urls["googlenet"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -76,7 +74,6 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs): class QuantizableBasicConv2d(BasicConv2d): - def __init__(self, *args, **kwargs): super(QuantizableBasicConv2d, self).__init__(*args, **kwargs) self.relu = nn.ReLU() @@ -92,10 +89,8 @@ def fuse_model(self): class QuantizableInception(Inception): - def __init__(self, *args, **kwargs): - super(QuantizableInception, self).__init__( - conv_block=QuantizableBasicConv2d, *args, **kwargs) + super(QuantizableInception, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) self.cat = nn.quantized.FloatFunctional() def forward(self, x): @@ -104,10 +99,8 @@ def forward(self, x): class QuantizableInceptionAux(InceptionAux): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionAux, self).__init__( - conv_block=QuantizableBasicConv2d, *args, **kwargs) + super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.7) @@ -130,12 +123,9 @@ def forward(self, x): class QuantizableGoogLeNet(GoogLeNet): - def __init__(self, *args, **kwargs): super(QuantizableGoogLeNet, self).__init__( - blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], - *args, - **kwargs + blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs ) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 833d8fb8b75..db305ed097a 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -5,10 +5,10 @@ import torch.nn.functional as F 
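The quantized googlenet factory reflowed just above follows the same recipe as the other quantized models: build the float model, replace the ReLUs, then either fuse and convert it for the fbgemm backend or load a checkpoint. A hedged usage sketch follows; pretrained=False avoids a download, init_weights=False merely skips the slow scipy init for the demo, and running the int8 model assumes a CPU build with fbgemm support (typical on x86).

# Usage sketch for the quantized GoogLeNet factory shown above. With quantize=True
# the model is fused, calibrated on dummy data inside quantize_model, and converted
# to an int8 model that runs on the CPU.
import torch
from torchvision.models.quantization import googlenet

model = googlenet(pretrained=False, quantize=True, init_weights=False)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])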
from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs + from ..._internally_replaced_utils import load_state_dict_from_url from .utils import _replace_relu, quantize_model - __all__ = [ "QuantizableInception3", "inception_v3", @@ -17,8 +17,7 @@ quant_model_urls = { # fp32 weights ported from TensorFlow, quantized in PyTorch - "inception_v3_google_fbgemm": - "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" + "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" # noqa: E501 } @@ -57,7 +56,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs): if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] @@ -67,12 +66,11 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs): if not original_aux_logits: model.aux_logits = False model.AuxLogits = None - model_url = quant_model_urls['inception_v3_google' + '_' + backend] + model_url = quant_model_urls["inception_v3_google" + "_" + backend] else: - model_url = inception_module.model_urls['inception_v3_google'] + model_url = inception_module.model_urls["inception_v3_google"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -189,8 +187,8 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): QuantizableInceptionC, QuantizableInceptionD, QuantizableInceptionE, - QuantizableInceptionAux - ] + QuantizableInceptionAux, + ], ) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 8f2c42db640..c5e897e2d2b 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1,4 +1,2 @@ -from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all - -__all__ = mv2_all + mv3_all +from .mobilenetv2 import * +from .mobilenetv3 import * diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 857d919b1fa..23d55f7e4af 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,15 +1,14 @@ from torch import nn +from torch.quantization import DeQuantStub, QuantStub, fuse_modules +from torchvision.models.mobilenetv2 import ConvBNReLU, InvertedResidual, MobileNetV2, model_urls + from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls -from torch.quantization import QuantStub, DeQuantStub, fuse_modules from .utils import _replace_relu, quantize_model - -__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2'] +__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] quant_model_urls = { - 'mobilenet_v2_qnnpack': - 'https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth' + "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" } @@ -51,7 +50,7 @@ def forward(self, x): def fuse_model(self): 
for m in self.modules(): if type(m) == ConvBNReLU: - fuse_modules(m, ['0', '1', '2'], inplace=True) + fuse_modules(m, ["0", "1", "2"], inplace=True) if type(m) == QuantizableInvertedResidual: m.fuse_model() @@ -76,19 +75,18 @@ def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs): if quantize: # TODO use pretrained as a string to specify the backend - backend = 'qnnpack' + backend = "qnnpack" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls['mobilenet_v2_' + backend] + model_url = quant_model_urls["mobilenet_v2_" + backend] else: - model_url = model_urls['mobilenet_v2'] + model_url = model_urls["mobilenet_v2"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 5462af89127..d8f09d8baae 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,18 +1,25 @@ +from typing import Any, List, Optional + import torch -from torch import nn, Tensor +from torch import Tensor, nn +from torch.quantization import DeQuantStub, QuantStub, fuse_modules +from torchvision.models.mobilenetv3 import ( + ConvBNActivation, + InvertedResidual, + InvertedResidualConfig, + MobileNetV3, + SqueezeExcitation, + _mobilenet_v3_conf, + model_urls, +) + from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\ - SqueezeExcitation, model_urls, _mobilenet_v3_conf -from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from typing import Any, List, Optional from .utils import _replace_relu - -__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large'] +__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] quant_model_urls = { - 'mobilenet_v3_large_qnnpack': - "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", # noqa: E501 } @@ -25,7 +32,7 @@ def forward(self, input: Tensor) -> Tensor: return self.skip_mul.mul(self._scale(input, False), input) def fuse_model(self): - fuse_modules(self, ['fc1', 'relu'], inplace=True) + fuse_modules(self, ["fc1", "relu"], inplace=True) class QuantizableInvertedResidual(InvertedResidual): @@ -61,9 +68,9 @@ def forward(self, x): def fuse_model(self): for m in self.modules(): if type(m) == ConvBNActivation: - modules_to_fuse = ['0', '1'] + modules_to_fuse = ["0", "1"] if type(m[2]) == nn.ReLU: - modules_to_fuse.append('2') + modules_to_fuse.append("2") fuse_modules(m, modules_to_fuse, inplace=True) elif type(m) == QuantizableSqueezeExcitation: m.fuse_model() @@ -88,20 +95,20 @@ def _mobilenet_v3_model( pretrained: bool, progress: bool, quantize: bool, - **kwargs: Any + **kwargs: Any, ): model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) if quantize: - backend = 'qnnpack' + backend = "qnnpack" model.fuse_model() model.qconfig = torch.quantization.get_default_qat_qconfig(backend) torch.quantization.prepare_qat(model, inplace=True) if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + '_' + 
backend, None), progress) + _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress) torch.quantization.convert(model, inplace=True) model.eval() diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 2f3f50e8013..354374f638c 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,21 +1,18 @@ import torch -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls import torch.nn as nn -from ..._internally_replaced_utils import load_state_dict_from_url from torch.quantization import fuse_modules +from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls + +from ..._internally_replaced_utils import load_state_dict_from_url from .utils import _replace_relu, quantize_model -__all__ = ['QuantizableResNet', 'resnet18', 'resnet50', - 'resnext101_32x8d'] +__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] quant_model_urls = { - 'resnet18_fbgemm': - 'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth', - 'resnet50_fbgemm': - 'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth', - 'resnext101_32x8d_fbgemm': - 'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth', + "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", } @@ -42,10 +39,9 @@ def forward(self, x): return out def fuse_model(self): - torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], - ['conv2', 'bn2']], inplace=True) + torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True) class QuantizableBottleneck(Bottleneck): @@ -74,15 +70,12 @@ def forward(self, x): return out def fuse_model(self): - fuse_modules(self, [['conv1', 'bn1', 'relu1'], - ['conv2', 'bn2', 'relu2'], - ['conv3', 'bn3']], inplace=True) + fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True) class QuantizableResNet(ResNet): - def __init__(self, *args, **kwargs): super(QuantizableResNet, self).__init__(*args, **kwargs) @@ -106,7 +99,7 @@ def fuse_model(self): and the model after modification is in floating point """ - fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True) + fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True) for m in self.modules(): if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock: m.fuse_model() @@ -117,19 +110,18 @@ def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs): _replace_relu(model) if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls[arch + '_' + backend] + model_url = quant_model_urls[arch + "_" + backend] else: model_url = 
model_urls[arch] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model @@ -144,8 +136,7 @@ def resnet18(pretrained=False, progress=True, quantize=False, **kwargs): progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, - quantize, **kwargs) + return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) def resnet50(pretrained=False, progress=True, quantize=False, **kwargs): @@ -157,8 +148,7 @@ def resnet50(pretrained=False, progress=True, quantize=False, **kwargs): progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, - quantize, **kwargs) + return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs): @@ -170,7 +160,6 @@ def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs): progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', QuantizableBottleneck, [3, 4, 23, 3], - pretrained, progress, quantize, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 17885015772..3878589a334 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,23 +1,23 @@ import torch import torch.nn as nn +from torchvision.models import shufflenetv2 + from ..._internally_replaced_utils import load_state_dict_from_url -import torchvision.models.shufflenetv2 -import sys from .utils import _replace_relu, quantize_model -shufflenetv2 = sys.modules['torchvision.models.shufflenetv2'] - __all__ = [ - 'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', - 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' + "QuantizableShuffleNetV2", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", ] quant_model_urls = { - 'shufflenetv2_x0.5_fbgemm': None, - 'shufflenetv2_x1.0_fbgemm': - 'https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth', - 'shufflenetv2_x1.5_fbgemm': None, - 'shufflenetv2_x2.0_fbgemm': None, + "shufflenetv2_x0.5_fbgemm": None, + "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", + "shufflenetv2_x1.5_fbgemm": None, + "shufflenetv2_x2.0_fbgemm": None, } @@ -64,9 +64,7 @@ def fuse_model(self): for m in self.modules(): if type(m) == QuantizableInvertedResidual: if len(m.branch1._modules.items()) > 0: - torch.quantization.fuse_modules( - m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True - ) + torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True) torch.quantization.fuse_modules( 
m.branch2, [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]], @@ -80,19 +78,18 @@ def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs): if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls[arch + '_' + backend] + model_url = quant_model_urls[arch + "_" + backend] else: model_url = shufflenetv2.model_urls[arch] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model @@ -109,8 +106,9 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, quantize, - [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + ) def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs): @@ -124,8 +122,9 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, quantize, - [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + ) def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs): @@ -139,8 +138,9 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, quantize, - [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x1.5", pretrained, progress, quantize, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs + ) def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs): @@ -154,5 +154,6 @@ def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, quantize, - [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2( + "shufflenetv2_x2.0", pretrained, progress, quantize, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs + ) diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py index bf23c9a9332..d1477a3a0b4 100644 --- a/torchvision/models/quantization/utils.py +++ b/torchvision/models/quantization/utils.py @@ -23,14 +23,15 @@ def quantize_model(model, backend): torch.backends.quantized.engine = backend model.eval() # Make sure that weight qconfig matches that of the serialized models - if backend == 'fbgemm': + if backend == "fbgemm": model.qconfig = torch.quantization.QConfig( activation=torch.quantization.default_observer, - 
weight=torch.quantization.default_per_channel_weight_observer) - elif backend == 'qnnpack': + weight=torch.quantization.default_per_channel_weight_observer, + ) + elif backend == "qnnpack": model.qconfig = torch.quantization.QConfig( - activation=torch.quantization.default_observer, - weight=torch.quantization.default_weight_observer) + activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer + ) model.fuse_model() torch.quantization.prepare(model, inplace=True) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 6d708767441..33a44daf978 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,32 +1,50 @@ +from typing import Any, Callable, List, Optional, Type, Union + import torch -from torch import Tensor import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Type, Any, Callable, Union, List, Optional +from torch import Tensor +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] +__all__ = [ + "ResNet", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "wide_resnet50_2", + "wide_resnet101_2", +] model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", } def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: @@ -46,13 +64,13 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., 
nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -101,12 +119,12 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) @@ -142,7 +160,6 @@ def forward(self, x: Tensor) -> Tensor: class ResNet(nn.Module): - def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], @@ -152,7 +169,7 @@ def __init__( groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: @@ -166,28 +183,26 @@ def __init__( # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) self.groups = groups self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -202,8 +217,14 @@ def __init__( elif isinstance(m, 
BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, - stride: int = 1, dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -217,13 +238,23 @@ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, b ) layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) @@ -255,12 +286,11 @@ def _resnet( layers: List[int], pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -273,8 +303,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) + return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -285,8 +314,7 @@ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) + return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -297,8 +325,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -309,8 +336,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) + return 
_resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -321,8 +347,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) + return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -333,10 +358,9 @@ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 4 + return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -347,10 +371,9 @@ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -366,9 +389,8 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) + kwargs["width_per_group"] = 64 * 2 + return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -384,6 +406,5 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) + kwargs["width_per_group"] = 64 * 2 + return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py index fb6633d7fb5..62b564da486 100644 --- a/torchvision/models/segmentation/__init__.py +++ b/torchvision/models/segmentation/__init__.py @@ -1,4 +1,4 @@ -from .segmentation import * -from .fcn import * from .deeplabv3 import * +from .fcn import * from .lraspp import * +from .segmentation import * diff --git 
a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 176b7490038..16677e86411 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -5,7 +5,7 @@ class _SimpleSegmentationModel(nn.Module): - __constants__ = ['aux_classifier'] + __constants__ = ["aux_classifier"] def __init__(self, backbone, classifier, aux_classifier=None): super(_SimpleSegmentationModel, self).__init__() @@ -21,13 +21,13 @@ def forward(self, x): result = OrderedDict() x = features["out"] x = self.classifier(x) - x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) result["out"] = x if self.aux_classifier is not None: x = features["aux"] x = self.aux_classifier(x) - x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) result["aux"] = x return result diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 7acc013ccb1..843e715dbd6 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -4,7 +4,6 @@ from ._utils import _SimpleSegmentationModel - __all__ = ["DeepLabV3"] @@ -23,6 +22,7 @@ class DeepLabV3(_SimpleSegmentationModel): the backbone and returns a dense prediction. aux_classifier (nn.Module, optional): auxiliary classifier used during training """ + pass @@ -33,7 +33,7 @@ def __init__(self, in_channels, num_classes): nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), - nn.Conv2d(256, num_classes, 1) + nn.Conv2d(256, num_classes, 1), ) @@ -42,7 +42,7 @@ def __init__(self, in_channels, out_channels, dilation): modules = [ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), nn.BatchNorm2d(out_channels), - nn.ReLU() + nn.ReLU(), ] super(ASPPConv, self).__init__(*modules) @@ -53,23 +53,23 @@ def __init__(self, in_channels, out_channels): nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), - nn.ReLU()) + nn.ReLU(), + ) def forward(self, x): size = x.shape[-2:] for mod in self: x = mod(x) - return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) class ASPP(nn.Module): def __init__(self, in_channels, atrous_rates, out_channels=256): super(ASPP, self).__init__() modules = [] - modules.append(nn.Sequential( - nn.Conv2d(in_channels, out_channels, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU())) + modules.append( + nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) + ) rates = tuple(atrous_rates) for rate in rates: @@ -83,7 +83,8 @@ def __init__(self, in_channels, atrous_rates, out_channels=256): nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), - nn.Dropout(0.5)) + nn.Dropout(0.5), + ) def forward(self, x): res = [] diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 3c695b53167..9b86e433b66 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -2,7 +2,6 @@ from ._utils import _SimpleSegmentationModel - __all__ = ["FCN"] @@ -19,6 +18,7 @@ class FCN(_SimpleSegmentationModel): the backbone and 
returns a dense prediction. aux_classifier (nn.Module, optional): auxiliary classifier used during training """ + pass @@ -30,7 +30,7 @@ def __init__(self, in_channels, channels): nn.BatchNorm2d(inter_channels), nn.ReLU(), nn.Dropout(0.1), - nn.Conv2d(inter_channels, channels, 1) + nn.Conv2d(inter_channels, channels, 1), ] super(FCNHead, self).__init__(*layers) diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 44cd9b1e773..b3ce32abc16 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,9 +1,8 @@ from collections import OrderedDict - -from torch import nn, Tensor -from torch.nn import functional as F from typing import Dict +from torch import Tensor, nn +from torch.nn import functional as F __all__ = ["LRASPP"] @@ -32,7 +31,7 @@ def __init__(self, backbone, low_channels, high_channels, num_classes, inter_cha def forward(self, input): features = self.backbone(input) out = self.classifier(features) - out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False) + out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False) result = OrderedDict() result["out"] = out @@ -41,13 +40,12 @@ def forward(self, input): class LRASPPHead(nn.Module): - def __init__(self, low_channels, high_channels, num_classes, inter_channels): super().__init__() self.cbr = nn.Sequential( nn.Conv2d(high_channels, inter_channels, 1, bias=False), nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) self.scale = nn.Sequential( nn.AdaptiveAvgPool2d(1), @@ -64,6 +62,6 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: x = self.cbr(high) s = self.scale(high) x = x * s - x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False) + x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) return self.low_classifier(low) + self.high_classifier(x) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 7b3a0258ddb..55c6d9fe541 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,37 +1,40 @@ -from .._utils import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url -from .. import mobilenetv3 -from .. import resnet +from .. 
import mobilenetv3, resnet +from .._utils import IntermediateLayerGetter from .deeplabv3 import DeepLabHead, DeepLabV3 from .fcn import FCN, FCNHead from .lraspp import LRASPP - -__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', - 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large'] +__all__ = [ + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", +] model_urls = { - 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth', - 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth', - 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth', - 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', - 'deeplabv3_mobilenet_v3_large_coco': - 'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth', - 'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth', + "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", # noqa: E501 + "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", } def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True): - if 'resnet' in backbone_name: + if "resnet" in backbone_name: backbone = resnet.__dict__[backbone_name]( - pretrained=pretrained_backbone, - replace_stride_with_dilation=[False, True, True]) - out_layer = 'layer4' + pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True] + ) + out_layer = "layer4" out_inplanes = 2048 - aux_layer = 'layer3' + aux_layer = "layer3" aux_inplanes = 1024 - elif 'mobilenet_v3' in backbone_name: + elif "mobilenet_v3" in backbone_name: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. 
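For context, the segmentation builders reformatted in this hunk are used roughly as in the following minimal sketch (assuming a torchvision build from this era; the COCO weights listed in model_urls are only fetched when pretrained=True, and the ImageNet backbone weights are still downloaded by default via pretrained_backbone=True):

import torch
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large

# Build an untrained head on top of a MobileNetV3-Large backbone (21 classes incl. background).
model = deeplabv3_mobilenet_v3_large(pretrained=False, num_classes=21)
model.eval()

with torch.no_grad():
    # The forward pass returns an OrderedDict; "out" is upsampled to the input resolution.
    out = model(torch.rand(1, 3, 480, 480))["out"]  # shape: (1, 21, 480, 480)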
@@ -44,11 +47,11 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True) aux_layer = str(aux_pos) aux_inplanes = backbone[aux_pos].out_channels else: - raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name)) + raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name)) - return_layers = {out_layer: 'out'} + return_layers = {out_layer: "out"} if aux: - return_layers[aux_layer] = 'aux' + return_layers[aux_layer] = "aux" backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = None @@ -56,8 +59,8 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True) aux_classifier = FCNHead(aux_inplanes, num_classes) model_map = { - 'deeplabv3': (DeepLabHead, DeepLabV3), - 'fcn': (FCNHead, FCN), + "deeplabv3": (DeepLabHead, DeepLabV3), + "fcn": (FCNHead, FCN), } classifier = model_map[name][0](out_inplanes, num_classes) base_model = model_map[name][1] @@ -77,10 +80,10 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss def _load_weights(model, arch_type, backbone, progress): - arch = arch_type + '_' + backbone + '_coco' + arch = arch_type + "_" + backbone + "_coco" model_url = model_urls.get(arch, None) if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + raise NotImplementedError("pretrained {} is not supported as of now".format(arch)) else: state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -97,14 +100,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'}) + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) model = LRASPP(backbone, low_channels, high_channels, num_classes) return model -def fcn_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: @@ -114,11 +116,10 @@ def fcn_resnet50(pretrained=False, progress=True, num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) -def fcn_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def fcn_resnet101(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. 
Args: @@ -128,11 +129,10 @@ def fcn_resnet101(pretrained=False, progress=True, num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: @@ -142,11 +142,10 @@ def deeplabv3_resnet50(pretrained=False, progress=True, num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_resnet101(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_resnet101(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: @@ -156,11 +155,10 @@ def deeplabv3_resnet101(pretrained=False, progress=True, num_classes (int): The number of classes aux_loss (bool): If True, include an auxiliary classifier """ - return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) -def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): +def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. 
Args: @@ -170,7 +168,7 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs) def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs): @@ -183,12 +181,12 @@ def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, * num_classes (int): number of output classes of the model (including the background) """ if kwargs.pop("aux_loss", False): - raise NotImplementedError('This model does not use auxiliary loss') + raise NotImplementedError("This model does not use auxiliary loss") - backbone_name = 'mobilenet_v3_large' + backbone_name = "mobilenet_v3_large" model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs) if pretrained: - _load_weights(model, 'lraspp', backbone_name, progress) + _load_weights(model, "lraspp", backbone_name, progress) return model diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 65d60a09e6c..cc2b109cd46 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,20 +1,18 @@ +from typing import Any, Callable, List + import torch -from torch import Tensor import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Callable, Any, List +from torch import Tensor +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = [ - 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', - 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' -] +__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] model_urls = { - 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', - 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', - 'shufflenetv2_x1.5': None, - 'shufflenetv2_x2.0': None, + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": None, + "shufflenetv2_x2.0": None, } @@ -23,8 +21,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: channels_per_group = num_channels // groups # reshape - x = x.view(batchsize, groups, - channels_per_group, height, width) + x = x.view(batchsize, groups, channels_per_group, height, width) x = torch.transpose(x, 1, 2).contiguous() @@ -35,16 +32,11 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: class InvertedResidual(nn.Module): - def __init__( - self, - inp: int, - oup: int, - stride: int - ) -> None: + def __init__(self, inp: int, oup: int, stride: int) -> None: super(InvertedResidual, self).__init__() if not (1 <= stride <= 3): - raise ValueError('illegal stride value') + raise ValueError("illegal stride value") self.stride = stride branch_features = oup // 2 @@ -62,8 +54,14 @@ def __init__( self.branch1 = nn.Sequential() self.branch2 = nn.Sequential( - nn.Conv2d(inp if (self.stride > 1) else branch_features, - branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + 
stride=1, + padding=0, + bias=False, + ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), @@ -75,12 +73,7 @@ def __init__( @staticmethod def depthwise_conv( - i: int, - o: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - bias: bool = False + i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False ) -> nn.Conv2d: return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) @@ -102,14 +95,14 @@ def __init__( stages_repeats: List[int], stages_out_channels: List[int], num_classes: int = 1000, - inverted_residual: Callable[..., nn.Module] = InvertedResidual + inverted_residual: Callable[..., nn.Module] = InvertedResidual, ) -> None: super(ShuffleNetV2, self).__init__() if len(stages_repeats) != 3: - raise ValueError('expected stages_repeats as list of 3 positive ints') + raise ValueError("expected stages_repeats as list of 3 positive ints") if len(stages_out_channels) != 5: - raise ValueError('expected stages_out_channels as list of 5 positive ints') + raise ValueError("expected stages_out_channels as list of 5 positive ints") self._stage_out_channels = stages_out_channels input_channels = 3 @@ -127,9 +120,8 @@ def __init__( self.stage2: nn.Sequential self.stage3: nn.Sequential self.stage4: nn.Sequential - stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] - for name, repeats, output_channels in zip( - stage_names, stages_repeats, self._stage_out_channels[1:]): + stage_names = ["stage{}".format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]): seq = [inverted_residual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(inverted_residual(output_channels, output_channels, 1)) @@ -167,7 +159,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa if pretrained: model_url = model_urls[arch] if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + raise NotImplementedError("pretrained {} is not supported as of now".format(arch)) else: state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -185,8 +177,7 @@ def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, - [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -199,8 +190,7 @@ def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, - [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -213,8 +203,7 @@ def shufflenet_v2_x1_5(pretrained: bool = False, 
progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, - [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -227,5 +216,4 @@ def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, - [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index c54e475d412..e6258502da0 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,55 +1,42 @@ +from typing import Any + import torch import torch.nn as nn import torch.nn.init as init + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any -__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] +__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] model_urls = { - 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth', - 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth', + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", } class Fire(nn.Module): - - def __init__( - self, - inplanes: int, - squeeze_planes: int, - expand1x1_planes: int, - expand3x3_planes: int - ) -> None: + def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None: super(Fire, self).__init__() self.inplanes = inplanes self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze_activation = nn.ReLU(inplace=True) - self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, - kernel_size=1) + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1) self.expand1x1_activation = nn.ReLU(inplace=True) - self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, - kernel_size=3, padding=1) + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1) self.expand3x3_activation = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.squeeze_activation(self.squeeze(x)) - return torch.cat([ - self.expand1x1_activation(self.expand1x1(x)), - self.expand3x3_activation(self.expand3x3(x)) - ], 1) + return torch.cat( + [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1 + ) class SqueezeNet(nn.Module): - - def __init__( - self, - version: str = '1_0', - num_classes: int = 1000 - ) -> None: + def __init__(self, version: str = "1_0", num_classes: int = 1000) -> None: super(SqueezeNet, self).__init__() self.num_classes = num_classes - if version == '1_0': + if version == "1_0": self.features = nn.Sequential( nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.ReLU(inplace=True), @@ -65,7 +52,7 @@ def __init__( nn.MaxPool2d(kernel_size=3, stride=2, 
ceil_mode=True), Fire(512, 64, 256, 256), ) - elif version == '1_1': + elif version == "1_1": self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.ReLU(inplace=True), @@ -85,16 +72,12 @@ def __init__( # FIXME: Is this needed? SqueezeNet should only be called from the # FIXME: squeezenet1_x() functions # FIXME: This checking is not done for the other models - raise ValueError("Unsupported SqueezeNet version {version}:" - "1_0 or 1_1 expected".format(version=version)) + raise ValueError("Unsupported SqueezeNet version {version}:" "1_0 or 1_1 expected".format(version=version)) # Final convolution is initialized differently from the rest final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) self.classifier = nn.Sequential( - nn.Dropout(p=0.5), - final_conv, - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool2d((1, 1)) + nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)) ) for m in self.modules(): @@ -115,9 +98,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: model = SqueezeNet(version, **kwargs) if pretrained: - arch = 'squeezenet' + version - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + arch = "squeezenet" + version + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -132,7 +114,7 @@ def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet('1_0', pretrained, progress, **kwargs) + return _squeezenet("1_0", pretrained, progress, **kwargs) def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: @@ -146,4 +128,4 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet('1_1', pretrained, progress, **kwargs) + return _squeezenet("1_1", pretrained, progress, **kwargs) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 619bce97b2f..7a62264b217 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,35 +1,37 @@ +from typing import Any, Dict, List, Union, cast + import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url -from typing import Union, List, Dict, Any, cast +from .._internally_replaced_utils import load_state_dict_from_url __all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', + "VGG", + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19_bn", + "vgg19", ] model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', - 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', - 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', - 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', - 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', + 
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", } class VGG(nn.Module): - - def __init__( - self, - features: nn.Module, - num_classes: int = 1000, - init_weights: bool = True - ) -> None: + def __init__(self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True) -> None: super(VGG, self).__init__() self.features = features self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) @@ -55,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): @@ -70,7 +72,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ layers: List[nn.Module] = [] in_channels = 3 for v in cfg: - if v == 'M': + if v == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: v = cast(int, v) @@ -84,20 +86,19 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ cfgs: Dict[str, List[Union[str, int]]] = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: if pretrained: - kwargs['init_weights'] = False + kwargs["init_weights"] = False model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -111,7 +112,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) + return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -123,7 +124,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on 
ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) + return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -135,7 +136,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) + return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -147,7 +148,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) + return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -159,7 +160,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) + return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -171,7 +172,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) + return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -183,7 +184,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) + return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -195,4 +196,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) + return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index fc69188ef7a..51eb7c2e31f 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -2,23 +2,17 @@ from ..._internally_replaced_utils import load_state_dict_from_url - -__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] +__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"] model_urls = { - 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', - 'mc3_18': 
'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', - 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', + "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", } class Conv3DSimple(nn.Conv3d): - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): + def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): super(Conv3DSimple, self).__init__( in_channels=in_planes, @@ -26,7 +20,8 @@ def __init__(self, kernel_size=(3, 3, 3), stride=stride, padding=padding, - bias=False) + bias=False, + ) @staticmethod def get_downsample_stride(stride): @@ -34,22 +29,22 @@ def get_downsample_stride(stride): class Conv2Plus1D(nn.Sequential): - - def __init__(self, - in_planes, - out_planes, - midplanes, - stride=1, - padding=1): + def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1): super(Conv2Plus1D, self).__init__( - nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), - stride=(1, stride, stride), padding=(0, padding, padding), - bias=False), + nn.Conv3d( + in_planes, + midplanes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ), nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), - nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), - stride=(stride, 1, 1), padding=(padding, 0, 0), - bias=False)) + nn.Conv3d( + midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False + ), + ) @staticmethod def get_downsample_stride(stride): @@ -57,13 +52,7 @@ def get_downsample_stride(stride): class Conv3DNoTemporal(nn.Conv3d): - - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): + def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): super(Conv3DNoTemporal, self).__init__( in_channels=in_planes, @@ -71,7 +60,8 @@ def __init__(self, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), - bias=False) + bias=False, + ) @staticmethod def get_downsample_stride(stride): @@ -87,14 +77,9 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = nn.Sequential( - conv_builder(inplanes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) - ) - self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes), - nn.BatchNorm3d(planes) + conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) + self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes)) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -123,21 +108,17 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): # 1x1x1 self.conv1 = nn.Sequential( - nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) # Second kernel self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) + conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) # 1x1x1 self.conv3 = nn.Sequential( nn.Conv3d(planes, planes * 
self.expansion, kernel_size=1, bias=False), - nn.BatchNorm3d(planes * self.expansion) + nn.BatchNorm3d(planes * self.expansion), ) self.relu = nn.ReLU(inplace=True) self.downsample = downsample @@ -160,38 +141,32 @@ def forward(self, x): class BasicStem(nn.Sequential): - """The default conv-batchnorm-relu stem - """ + """The default conv-batchnorm-relu stem""" + def __init__(self): super(BasicStem, self).__init__( - nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), - padding=(1, 3, 3), bias=False), + nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False), nn.BatchNorm3d(64), - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) class R2Plus1dStem(nn.Sequential): - """R(2+1)D stem is different than the default one as it uses separated 3D convolution - """ + """R(2+1)D stem is different than the default one as it uses separated 3D convolution""" + def __init__(self): super(R2Plus1dStem, self).__init__( - nn.Conv3d(3, 45, kernel_size=(1, 7, 7), - stride=(1, 2, 2), padding=(0, 3, 3), - bias=False), + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False), nn.BatchNorm3d(45), nn.ReLU(inplace=True), - nn.Conv3d(45, 64, kernel_size=(3, 1, 1), - stride=(1, 1, 1), padding=(1, 0, 0), - bias=False), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), nn.BatchNorm3d(64), - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) class VideoResNet(nn.Module): - - def __init__(self, block, conv_makers, layers, - stem, num_classes=400, - zero_init_residual=False): + def __init__(self, block, conv_makers, layers, stem, num_classes=400, zero_init_residual=False): """Generic resnet video generator. Args: @@ -244,9 +219,8 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1): if stride != 1 or self.inplanes != planes * block.expansion: ds_stride = conv_builder.get_downsample_stride(stride) downsample = nn.Sequential( - nn.Conv3d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=ds_stride, bias=False), - nn.BatchNorm3d(planes * block.expansion) + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) @@ -260,8 +234,7 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1): def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv3d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', - nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm3d): @@ -276,8 +249,7 @@ def _video_resnet(arch, pretrained=False, progress=True, **kwargs): model = VideoResNet(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -294,12 +266,16 @@ def r3d_18(pretrained=False, progress=True, **kwargs): nn.Module: R3D-18 network """ - return _video_resnet('r3d_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] * 4, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs) + return _video_resnet( + "r3d_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] * 4, + layers=[2, 2, 2, 2], + stem=BasicStem, + **kwargs, + ) def 
mc3_18(pretrained=False, progress=True, **kwargs): @@ -313,12 +289,16 @@ def mc3_18(pretrained=False, progress=True, **kwargs): Returns: nn.Module: MC3 Network definition """ - return _video_resnet('mc3_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs) + return _video_resnet( + "mc3_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, + layers=[2, 2, 2, 2], + stem=BasicStem, + **kwargs, + ) def r2plus1d_18(pretrained=False, progress=True, **kwargs): @@ -332,9 +312,13 @@ def r2plus1d_18(pretrained=False, progress=True, **kwargs): Returns: nn.Module: R(2+1)D-18 network """ - return _video_resnet('r2plus1d_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv2Plus1D] * 4, - layers=[2, 2, 2, 2], - stem=R2Plus1dStem, **kwargs) + return _video_resnet( + "r2plus1d_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[2, 2, 2, 2], + stem=R2Plus1dStem, + **kwargs, + ) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 0ec189dbc2a..83e94a8ed48 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,24 +1,46 @@ -from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou -from .boxes import box_convert -from .deform_conv import deform_conv2d, DeformConv2d -from .roi_align import roi_align, RoIAlign -from .roi_pool import roi_pool, RoIPool -from .ps_roi_align import ps_roi_align, PSRoIAlign -from .ps_roi_pool import ps_roi_pool, PSRoIPool -from .poolers import MultiScaleRoIAlign +from ._register_onnx_ops import _register_custom_op +from .boxes import ( + batched_nms, + box_area, + box_convert, + box_iou, + clip_boxes_to_image, + generalized_box_iou, + nms, + remove_small_boxes, +) +from .deform_conv import DeformConv2d, deform_conv2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss - -from ._register_onnx_ops import _register_custom_op +from .poolers import MultiScaleRoIAlign +from .ps_roi_align import PSRoIAlign, ps_roi_align +from .ps_roi_pool import PSRoIPool, ps_roi_pool +from .roi_align import RoIAlign, roi_align +from .roi_pool import RoIPool, roi_pool _register_custom_op() __all__ = [ - 'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes', - 'clip_boxes_to_image', 'box_convert', - 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', - 'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', - 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', - 'sigmoid_focal_loss' + "deform_conv2d", + "DeformConv2d", + "nms", + "batched_nms", + "remove_small_boxes", + "clip_boxes_to_image", + "box_convert", + "box_area", + "box_iou", + "generalized_box_iou", + "roi_align", + "RoIAlign", + "roi_pool", + "RoIPool", + "ps_roi_align", + "PSRoIAlign", + "ps_roi_pool", + "PSRoIPool", + "MultiScaleRoIAlign", + "FeaturePyramidNetwork", + "sigmoid_focal_loss", ] diff --git a/torchvision/ops/_register_onnx_ops.py b/torchvision/ops/_register_onnx_ops.py index 8e8ed331803..173b65162a4 100644 --- a/torchvision/ops/_register_onnx_ops.py +++ b/torchvision/ops/_register_onnx_ops.py @@ -1,50 +1,64 @@ import sys -import torch import warnings +import torch + _onnx_opset_version = 11 def _register_custom_op(): - from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx, 
scalar_type_to_pytorch_type, \ - cast_pytorch_to_onnx - from torch.onnx.symbolic_opset9 import select, unsqueeze, squeeze, _cast_Long, reshape + from torch.onnx.symbolic_helper import parse_args + from torch.onnx.symbolic_opset9 import _cast_Long, select, squeeze, unsqueeze - @parse_args('v', 'v', 'f') + @parse_args("v", "v", "f") def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): boxes = unsqueeze(g, boxes, 0) scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) - max_output_per_class = g.op('Constant', value_t=torch.tensor([sys.maxsize], dtype=torch.long)) - iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float)) - nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold) - return squeeze(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), 1) + max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) + iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) + nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) + return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) - @parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i') + @parse_args("v", "v", "f", "i", "i", "i", "i") def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant', - value_t=torch.tensor([0], dtype=torch.long))), 1), False) - rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) + batch_indices = _cast_Long( + g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False + ) + rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) if aligned: - warnings.warn("ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes," - " ONNX forces ROIs to be 1x1 or larger.") + warnings.warn( + "ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes," + " ONNX forces ROIs to be 1x1 or larger." + ) scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float) rois = g.op("Sub", rois, scale) # ONNX doesn't support negative sampling_ratio if sampling_ratio < 0: - warnings.warn("ONNX doesn't support negative sampling ratio," - "therefore is is set to 0 in order to be exported.") + warnings.warn( + "ONNX doesn't support negative sampling ratio," "therefore is is set to 0 in order to be exported." 
+ ) sampling_ratio = 0 - return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale, - output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio) + return g.op( + "RoiAlign", + input, + rois, + batch_indices, + spatial_scale_f=spatial_scale, + output_height_i=pooled_height, + output_width_i=pooled_width, + sampling_ratio_i=sampling_ratio, + ) - @parse_args('v', 'v', 'f', 'i', 'i') + @parse_args("v", "v", "f", "i", "i") def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): - roi_pool = g.op('MaxRoiPool', input, rois, - pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale) + roi_pool = g.op( + "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale + ) return roi_pool, None from torch.onnx import register_custom_op_symbolic - register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version) - register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version) - register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version) + + register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version) + register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version) + register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index bb6287ad616..62508850777 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -1,6 +1,7 @@ +from typing import List + import torch from torch import Tensor -from typing import List def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor: @@ -27,10 +28,11 @@ def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor: def check_roi_boxes_shape(boxes: Tensor): if isinstance(boxes, (list, tuple)): for _tensor in boxes: - assert _tensor.size(1) == 4, \ - 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' + assert ( + _tensor.size(1) == 4 + ), "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]" elif isinstance(boxes, torch.Tensor): - assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' + assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]" else: - assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]' + assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]" return diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index c1f176f4da9..6723b2298a1 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -1,10 +1,12 @@ -import torch -from torch import Tensor from typing import Tuple -from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh + +import torch import torchvision +from torch import Tensor from torchvision.extension import _assert_has_ops +from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh + def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: """ @@ -183,13 +185,13 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: if in_fmt == out_fmt: return boxes.clone() - if in_fmt != 'xyxy' and out_fmt != 'xyxy': + if in_fmt != "xyxy" and out_fmt != "xyxy": # convert to xyxy and change in_fmt xyxy if in_fmt == "xywh": boxes = _box_xywh_to_xyxy(boxes) elif in_fmt == "cxcywh": 
boxes = _box_cxcywh_to_xyxy(boxes) - in_fmt = 'xyxy' + in_fmt = "xyxy" if in_fmt == "xyxy": if out_fmt == "xywh": diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index 4399d441843..0f4521f021b 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -1,11 +1,11 @@ import math +from typing import Optional, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import init -from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair -from typing import Optional, Tuple +from torch.nn.parameter import Parameter from torchvision.extension import _assert_has_ops @@ -84,7 +84,9 @@ def deform_conv2d( "the shape of the offset tensor at dimension 1 is not valid. It should " "be a multiple of 2 * weight.size[2] * weight.size[3].\n" "Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format( - offset.shape[1], 2 * weights_h * weights_w)) + offset.shape[1], 2 * weights_h * weights_w + ) + ) return torch.ops.torchvision.deform_conv2d( input, @@ -92,12 +94,16 @@ def deform_conv2d( offset, mask, bias, - stride_h, stride_w, - pad_h, pad_w, - dil_h, dil_w, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, - use_mask,) + use_mask, + ) class DeformConv2d(nn.Module): @@ -119,9 +125,9 @@ def __init__( super(DeformConv2d, self).__init__() if in_channels % groups != 0: - raise ValueError('in_channels must be divisible by groups') + raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: - raise ValueError('out_channels must be divisible by groups') + raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels @@ -131,13 +137,14 @@ def __init__( self.dilation = _pair(dilation) self.groups = groups - self.weight = Parameter(torch.empty(out_channels, in_channels // groups, - self.kernel_size[0], self.kernel_size[1])) + self.weight = Parameter( + torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1]) + ) if bias: self.bias = Parameter(torch.empty(out_channels)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() @@ -160,18 +167,26 @@ def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) out_height, out_width]): masks to be applied for each position in the convolution kernel. 
""" - return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride, - padding=self.padding, dilation=self.dilation, mask=mask) + return deform_conv2d( + input, + offset, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + mask=mask, + ) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += '{in_channels}' - s += ', {out_channels}' - s += ', kernel_size={kernel_size}' - s += ', stride={stride}' - s += ', padding={padding}' if self.padding != (0, 0) else '' - s += ', dilation={dilation}' if self.dilation != (1, 1) else '' - s += ', groups={groups}' if self.groups != 1 else '' - s += ', bias=False' if self.bias is None else '' - s += ')' + s = self.__class__.__name__ + "(" + s += "{in_channels}" + s += ", {out_channels}" + s += ", kernel_size={kernel_size}" + s += ", stride={stride}" + s += ", padding={padding}" if self.padding != (0, 0) else "" + s += ", dilation={dilation}" if self.dilation != (1, 1) else "" + s += ", groups={groups}" if self.groups != 1 else "" + s += ", bias=False" if self.bias is None else "" + s += ")" return s.format(**self.__dict__) diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 7d72769ab07..5bbd973135a 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -1,9 +1,8 @@ from collections import OrderedDict +from typing import Dict, List, Optional, Tuple import torch.nn.functional as F -from torch import nn, Tensor - -from typing import Tuple, List, Dict, Optional +from torch import Tensor, nn class ExtraFPNBlock(nn.Module): @@ -21,6 +20,7 @@ class ExtraFPNBlock(nn.Module): of the FPN names (List[str]): the extended set of names for the results """ + def forward( self, results: List[Tensor], @@ -67,6 +67,7 @@ class FeaturePyramidNetwork(nn.Module): >>> ('feat3', torch.Size([1, 5, 8, 8]))] """ + def __init__( self, in_channels_list: List[int], @@ -165,6 +166,7 @@ class LastLevelMaxPool(ExtraFPNBlock): """ Applies a max_pool2d on top of the last feature map """ + def forward( self, x: List[Tensor], @@ -180,6 +182,7 @@ class LastLevelP6P7(ExtraFPNBlock): """ This module is used in RetinaNet to generate extra layers, P6 and P7. """ + def __init__(self, in_channels: int, out_channels: int): super(LastLevelP6P7, self).__init__() self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py index de18f30c83a..3f72273c39c 100644 --- a/torchvision/ops/focal_loss.py +++ b/torchvision/ops/focal_loss.py @@ -31,9 +31,7 @@ def sigmoid_focal_loss( Loss tensor with the reduction option applied. 
""" p = torch.sigmoid(inputs) - ce_loss = F.binary_cross_entropy_with_logits( - inputs, targets, reduction="none" - ) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = p * targets + (1 - p) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 7e43caa78d6..c86b71aa36b 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -9,9 +9,10 @@ """ import warnings +from typing import List, Optional + import torch from torch import Tensor -from typing import List, Optional class Conv2d(torch.nn.Conv2d): @@ -19,7 +20,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.Conv2d is deprecated and will be " - "removed in future versions, use torch.nn.Conv2d instead.", FutureWarning) + "removed in future versions, use torch.nn.Conv2d instead.", + FutureWarning, + ) class ConvTranspose2d(torch.nn.ConvTranspose2d): @@ -27,7 +30,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.ConvTranspose2d is deprecated and will be " - "removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning) + "removed in future versions, use torch.nn.ConvTranspose2d instead.", + FutureWarning, + ) class BatchNorm2d(torch.nn.BatchNorm2d): @@ -35,7 +40,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.BatchNorm2d is deprecated and will be " - "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning) + "removed in future versions, use torch.nn.BatchNorm2d instead.", + FutureWarning, + ) interpolate = torch.nn.functional.interpolate @@ -56,8 +63,7 @@ def __init__( ): # n=None for backward-compatibility if n is not None: - warnings.warn("`n` argument is deprecated and has been renamed `num_features`", - DeprecationWarning) + warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning) num_features = n super(FrozenBatchNorm2d, self).__init__() self.eps = eps @@ -76,13 +82,13 @@ def _load_from_state_dict( unexpected_keys: List[str], error_msgs: List[str], ): - num_batches_tracked_key = prefix + 'num_batches_tracked' + num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super(FrozenBatchNorm2d, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def forward(self, x: Tensor) -> Tensor: # move reshapes to the beginning diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index f4ff289299b..73a0bee6846 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,11 +1,11 @@ -import torch -from torch import nn, Tensor +from typing import Dict, List, Optional, Tuple, Union +import torch import torchvision -from torchvision.ops import roi_align +from torch import Tensor, nn from torchvision.ops.boxes import box_area -from typing import Optional, List, Dict, Tuple, Union +from .roi_align import roi_align # copying result_idx_in_level to a specific index in result[] @@ -16,15 +16,17 @@ def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor: first_result = unmerged_results[0] dtype, device = first_result.dtype, first_result.device - res = 
torch.zeros((levels.size(0), first_result.size(1), - first_result.size(2), first_result.size(3)), - dtype=dtype, device=device) + res = torch.zeros( + (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device + ) for level in range(len(unmerged_results)): index = torch.where(levels == level)[0].view(-1, 1, 1, 1) - index = index.expand(index.size(0), - unmerged_results[level].size(1), - unmerged_results[level].size(2), - unmerged_results[level].size(3)) + index = index.expand( + index.size(0), + unmerged_results[level].size(1), + unmerged_results[level].size(2), + unmerged_results[level].size(3), + ) res = res.scatter(0, index, unmerged_results[level]) return res @@ -116,10 +118,7 @@ class MultiScaleRoIAlign(nn.Module): """ - __annotations__ = { - 'scales': Optional[List[float]], - 'map_levels': Optional[LevelMapper] - } + __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]} def __init__( self, @@ -224,10 +223,11 @@ def forward( if num_levels == 1: return roi_align( - x_filtered[0], rois, + x_filtered[0], + rois, output_size=self.output_size, spatial_scale=scales[0], - sampling_ratio=self.sampling_ratio + sampling_ratio=self.sampling_ratio, ) mapper = self.map_levels @@ -240,7 +240,11 @@ def forward( dtype, device = x_filtered[0].dtype, x_filtered[0].device result = torch.zeros( - (num_rois, num_channels,) + self.output_size, + ( + num_rois, + num_channels, + ) + + self.output_size, dtype=dtype, device=device, ) @@ -251,9 +255,12 @@ def forward( rois_per_level = rois[idx_in_level] result_idx_in_level = roi_align( - per_level_feature, rois_per_level, + per_level_feature, + rois_per_level, output_size=self.output_size, - spatial_scale=scale, sampling_ratio=self.sampling_ratio) + spatial_scale=scale, + sampling_ratio=self.sampling_ratio, + ) if torchvision._is_tracing(): tracing_results.append(result_idx_in_level.to(dtype)) @@ -273,5 +280,7 @@ def forward( return result def __repr__(self) -> str: - return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " - f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})") + return ( + f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " + f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})" + ) diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index d42353e2b0d..0205fa45df1 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -1,10 +1,9 @@ import torch -from torch import nn, Tensor - +from torch import Tensor, nn from torch.nn.modules.utils import _pair - from torchvision.extension import _assert_has_ops -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape + +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_align( @@ -47,10 +46,9 @@ def ps_roi_align( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale, - output_size[0], - output_size[1], - sampling_ratio) + output, _ = torch.ops.torchvision.ps_roi_align( + input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio + ) return output @@ -58,6 +56,7 @@ class PSRoIAlign(nn.Module): """ See :func:`ps_roi_align`. 
""" + def __init__( self, output_size: int, @@ -70,13 +69,12 @@ def __init__( self.sampling_ratio = sampling_ratio def forward(self, input: Tensor, rois: Tensor) -> Tensor: - return ps_roi_align(input, rois, self.output_size, self.spatial_scale, - self.sampling_ratio) + return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index d0331e557fd..5bbf414fb18 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -1,10 +1,9 @@ import torch -from torch import nn, Tensor - +from torch import Tensor, nn from torch.nn.modules.utils import _pair - from torchvision.extension import _assert_has_ops -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape + +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_pool( @@ -41,9 +40,7 @@ def ps_roi_pool( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, - output_size[0], - output_size[1]) + output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]) return output @@ -51,6 +48,7 @@ class PSRoIPool(nn.Module): """ See :func:`ps_roi_pool`. 
""" + def __init__(self, output_size: int, spatial_scale: float): super(PSRoIPool, self).__init__() self.output_size = output_size @@ -60,8 +58,8 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return ps_roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index b589089aa42..130e06de8a5 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -1,11 +1,10 @@ import torch -from torch import nn, Tensor - -from torch.nn.modules.utils import _pair +from torch import Tensor, nn from torch.jit.annotations import BroadcastingList2 - +from torch.nn.modules.utils import _pair from torchvision.extension import _assert_has_ops -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape + +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_align( @@ -52,15 +51,16 @@ def roi_align( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - return torch.ops.torchvision.roi_align(input, rois, spatial_scale, - output_size[0], output_size[1], - sampling_ratio, aligned) + return torch.ops.torchvision.roi_align( + input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned + ) class RoIAlign(nn.Module): """ See :func:`roi_align`. """ + def __init__( self, output_size: BroadcastingList2[int], @@ -78,10 +78,10 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) - tmpstr += ', aligned=' + str(self.aligned) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ", aligned=" + str(self.aligned) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 90f2dd3d173..2764da3cbfe 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -1,11 +1,10 @@ import torch -from torch import nn, Tensor - -from torch.nn.modules.utils import _pair +from torch import Tensor, nn from torch.jit.annotations import BroadcastingList2 - +from torch.nn.modules.utils import _pair from torchvision.extension import _assert_has_ops -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape + +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_pool( @@ -41,8 +40,7 @@ def roi_pool( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, - output_size[0], output_size[1]) + output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]) return output @@ 
-50,6 +48,7 @@ class RoIPool(nn.Module): """ See :func:`roi_pool`. """ + def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float): super(RoIPool, self).__init__() self.output_size = output_size @@ -59,8 +58,8 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" return tmpstr diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 77680a14f0d..5b9513b27bc 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,2 +1,2 @@ -from .transforms import * from .autoaugment import * +from .transforms import * diff --git a/torchvision/transforms/_functional_video.py b/torchvision/transforms/_functional_video.py index 9eba0463a4f..5d3e0d373ca 100644 --- a/torchvision/transforms/_functional_video.py +++ b/torchvision/transforms/_functional_video.py @@ -1,10 +1,8 @@ -import torch import warnings +import torch -warnings.warn( - "The _functional_video module is deprecated. Please use the functional module instead." -) +warnings.warn("The _functional_video module is deprecated. Please use the functional module instead.") def _is_tensor_video_clip(clip): @@ -23,14 +21,12 @@ def crop(clip, i, j, h, w): clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) """ assert len(clip.size()) == 4, "clip should be a 4D tensor" - return clip[..., i:i + h, j:j + w] + return clip[..., i : i + h, j : j + w] def resize(clip, target_size, interpolation_mode): assert len(target_size) == 2, "target size should be tuple (height, width)" - return torch.nn.functional.interpolate( - clip, size=target_size, mode=interpolation_mode, align_corners=False - ) + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py index bfef1b440d1..16b227301b2 100644 --- a/torchvision/transforms/_transforms_video.py +++ b/torchvision/transforms/_transforms_video.py @@ -4,14 +4,10 @@ import random import warnings -from torchvision.transforms import ( - RandomCrop, - RandomResizedCrop, -) +from torchvision.transforms import RandomCrop, RandomResizedCrop from . import _functional_video as F - __all__ = [ "RandomCropVideo", "RandomResizedCropVideo", @@ -22,9 +18,7 @@ ] -warnings.warn( - "The _transforms_video module is deprecated. Please use the transforms module instead." -) +warnings.warn("The _transforms_video module is deprecated. 
Please use the transforms module instead.") class RandomCropVideo(RandomCrop): @@ -46,7 +40,7 @@ def __call__(self, clip): return F.crop(clip, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class RandomResizedCropVideo(RandomResizedCrop): @@ -79,10 +73,9 @@ def __call__(self, clip): return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) def __repr__(self): - return self.__class__.__name__ + \ - '(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format( - self.size, self.interpolation_mode, self.scale, self.ratio - ) + return self.__class__.__name__ + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format( + self.size, self.interpolation_mode, self.scale, self.ratio + ) class CenterCropVideo(object): @@ -103,7 +96,7 @@ def __call__(self, clip): return F.center_crop(clip, self.crop_size) def __repr__(self): - return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size) + return self.__class__.__name__ + "(crop_size={0})".format(self.crop_size) class NormalizeVideo(object): @@ -128,8 +121,7 @@ def __call__(self, clip): return F.normalize(clip, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format( - self.mean, self.std, self.inplace) + return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(self.mean, self.std, self.inplace) class ToTensorVideo(object): diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 97522945d2e..d730ae600f5 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -1,11 +1,11 @@ import math -import torch - from enum import Enum +from typing import List, Optional, Tuple + +import torch from torch import Tensor -from typing import List, Tuple, Optional -from . import functional as F, InterpolationMode +from . import functional as F __all__ = ["AutoAugmentPolicy", "AutoAugment"] @@ -14,6 +14,7 @@ class AutoAugmentPolicy(Enum): """AutoAugment policies learned on different datasets. Available policies are IMAGENET, CIFAR10 and SVHN. """ + IMAGENET = "imagenet" CIFAR10 = "cifar10" SVHN = "svhn" @@ -144,8 +145,12 @@ class AutoAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. 
""" - def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, - interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ): super().__init__() self.policy = policy self.interpolation = interpolation @@ -191,23 +196,54 @@ def forward(self, img: Tensor): for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: magnitudes, signed = self._get_op_meta(op_name) - magnitude = float(magnitudes[magnitude_id].item()) \ - if magnitudes is not None and magnitude_id is not None else 0.0 + magnitude = ( + float(magnitudes[magnitude_id].item()) + if magnitudes is not None and magnitude_id is not None + else 0.0 + ) if signed is not None and signed and signs[i] == 0: magnitude *= -1.0 if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=self.interpolation, fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=self.interpolation, + fill=fill, + ) elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=self.interpolation, fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=self.interpolation, + fill=fill, + ) elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[int(F._get_image_size(img)[0] * magnitude), 0], + scale=1.0, + interpolation=self.interpolation, + shear=[0.0, 0.0], + fill=fill, + ) elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, int(F._get_image_size(img)[1] * magnitude)], + scale=1.0, + interpolation=self.interpolation, + shear=[0.0, 0.0], + fill=fill, + ) elif op_name == "Rotate": img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill) elif op_name == "Brightness": @@ -234,4 +270,4 @@ def forward(self, img: Tensor): return img def __repr__(self): - return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + return self.__class__.__name__ + "(policy={}, fill={})".format(self.policy, self.fill) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index dc3a9f8f68b..2c60540b3cd 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,27 +2,27 @@ import numbers import warnings from enum import Enum +from typing import Any, List, Optional, Tuple import numpy as np from PIL import Image import torch from torch import Tensor -from typing import List, Tuple, Any, Optional + +from . import functional_pil as F_pil, functional_tensor as F_t try: import accimage except ImportError: accimage = None -from . import functional_pil as F_pil -from . 
import functional_tensor as F_t - class InterpolationMode(Enum): """Interpolation modes Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. """ + NEAREST = "nearest" BILINEAR = "bilinear" BICUBIC = "bicubic" @@ -59,8 +59,7 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode: def _get_image_size(img: Tensor) -> List[int]: - """Returns image size as [w, h] - """ + """Returns image size as [w, h]""" if isinstance(img, torch.Tensor): return F_t._get_image_size(img) @@ -68,8 +67,7 @@ def _get_image_size(img: Tensor) -> List[int]: def _get_image_num_channels(img: Tensor) -> int: - """Returns number of image channels - """ + """Returns number of image channels""" if isinstance(img, torch.Tensor): return F_t._get_image_num_channels(img) @@ -98,11 +96,11 @@ def to_tensor(pic): Returns: Tensor: Converted image. """ - if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + if not (F_pil._is_pil_image(pic) or _is_numpy(pic)): + raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic))) if _is_numpy(pic) and not _is_numpy_image(pic): - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim)) default_float_dtype = torch.get_default_dtype() @@ -124,12 +122,10 @@ def to_tensor(pic): return torch.from_numpy(nppic).to(dtype=default_float_dtype) # handle PIL Image - mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} - img = torch.from_numpy( - np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True) - ) + mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32} + img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)) - if pic.mode == '1': + if pic.mode == "1": img = 255 * img img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format @@ -153,7 +149,7 @@ def pil_to_tensor(pic): Tensor: Converted image. """ if not F_pil._is_pil_image(pic): - raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) + raise TypeError("pic should be PIL Image. Got {}".format(type(pic))) if accimage is not None and isinstance(pic, accimage.Image): # accimage format is always uint8 internally, so always return uint8 here @@ -192,7 +188,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - of the integer ``dtype``. """ if not isinstance(image, torch.Tensor): - raise TypeError('Input img should be Tensor Image') + raise TypeError("Input img should be Tensor Image") return F_t.convert_image_dtype(image, dtype) @@ -211,12 +207,12 @@ def to_pil_image(pic, mode=None): Returns: PIL Image: Image converted to PIL Image. """ - if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): - raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): + raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic))) elif isinstance(pic, torch.Tensor): if pic.ndimension() not in {2, 3}: - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) + raise ValueError("pic should be 2/3 dimensional. 
Got {} dimensions.".format(pic.ndimension())) elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) @@ -224,11 +220,11 @@ def to_pil_image(pic, mode=None): # check number of channels if pic.shape[-3] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3])) + raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-3])) elif isinstance(pic, np.ndarray): if pic.ndim not in {2, 3}: - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim)) elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) @@ -236,58 +232,58 @@ def to_pil_image(pic, mode=None): # check number of channels if pic.shape[-1] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1])) + raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-1])) npimg = pic if isinstance(pic, torch.Tensor): - if pic.is_floating_point() and mode != 'F': + if pic.is_floating_point() and mode != "F": pic = pic.mul(255).byte() npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) if not isinstance(npimg, np.ndarray): - raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + - 'not {}'.format(type(npimg))) + raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, " + "not {}".format(type(npimg))) if npimg.shape[2] == 1: expected_mode = None npimg = npimg[:, :, 0] if npimg.dtype == np.uint8: - expected_mode = 'L' + expected_mode = "L" elif npimg.dtype == np.int16: - expected_mode = 'I;16' + expected_mode = "I;16" elif npimg.dtype == np.int32: - expected_mode = 'I' + expected_mode = "I" elif npimg.dtype == np.float32: - expected_mode = 'F' + expected_mode = "F" if mode is not None and mode != expected_mode: - raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" - .format(mode, np.dtype, expected_mode)) + raise ValueError( + "Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, np.dtype, expected_mode) + ) mode = expected_mode elif npimg.shape[2] == 2: - permitted_2_channel_modes = ['LA'] + permitted_2_channel_modes = ["LA"] if mode is not None and mode not in permitted_2_channel_modes: raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'LA' + mode = "LA" elif npimg.shape[2] == 4: - permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] + permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"] if mode is not None and mode not in permitted_4_channel_modes: raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'RGBA' + mode = "RGBA" else: - permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] + permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"] if mode is not None and mode not in permitted_3_channel_modes: raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'RGB' + mode = "RGB" if mode is None: - raise TypeError('Input type {} is not supported'.format(npimg.dtype)) + raise TypeError("Input type {} is not supported".format(npimg.dtype)) return Image.fromarray(npimg, mode=mode) @@ -311,14 +307,16 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool Tensor: Normalized Tensor image. 
""" if not isinstance(tensor, torch.Tensor): - raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor))) + raise TypeError("Input tensor should be a torch tensor. Got {}.".format(type(tensor))) if not tensor.is_floating_point(): - raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype)) + raise TypeError("Input tensor should be a float tensor. Got {}.".format(tensor.dtype)) if tensor.ndim < 3: - raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ' - '{}.'.format(tensor.size())) + raise ValueError( + "Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = " + "{}.".format(tensor.size()) + ) if not inplace: tensor = tensor.clone() @@ -327,7 +325,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) if (std == 0).any(): - raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) + raise ValueError("std evaluated to zero after conversion to {}, leading to division by zero.".format(dtype)) if mean.ndim == 1: mean = mean.view(-1, 1, 1) if std.ndim == 1: @@ -336,8 +334,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return tensor -def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor: +def resize( + img: Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> Tensor: r"""Resize the input image to the given size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -396,9 +399,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: - warnings.warn( - "Anti-alias option is always applied for PIL Image input. Argument antialias is ignored." - ) + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") pil_interpolation = pil_modes_mapping[interpolation] return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) @@ -406,8 +407,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte def scale(*args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") + warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.") return resize(*args, **kwargs) @@ -515,14 +515,19 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: if crop_width == image_width and crop_height == image_height: return img - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) return crop(img, crop_top, crop_left, crop_height, crop_width) def resized_crop( - img: Tensor, top: int, left: int, height: int, width: int, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR + img: Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> Tensor: """Crop the given image and resize it to desired size. If the image is torch Tensor, it is expected @@ -569,9 +574,7 @@ def hflip(img: Tensor) -> Tensor: return F_t.hflip(img) -def _get_perspective_coeffs( - startpoints: List[List[int]], endpoints: List[List[int]] -) -> List[float]: +def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]: """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. In Perspective Transform each pixel (x, y) in the original image gets transformed as, @@ -593,18 +596,18 @@ def _get_perspective_coeffs( a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8) - res = torch.linalg.lstsq(a_matrix, b_matrix, driver='gels').solution + res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution output: List[float] = res.tolist() return output def perspective( - img: Tensor, - startpoints: List[List[int]], - endpoints: List[List[int]], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[List[float]] = None + img: Tensor, + startpoints: List[List[int]], + endpoints: List[List[int]], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, ) -> Tensor: """Perform perspective transform of the given image. 
If the image is torch Tensor, it is expected @@ -880,7 +883,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _get_inverse_affine_matrix( - center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -930,9 +933,13 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, - expand: bool = False, center: Optional[List[int]] = None, - fill: Optional[List[float]] = None, resample: Optional[int] = None + img: Tensor, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, ) -> Tensor: """Rotate the image by angle. If the image is torch Tensor, it is expected @@ -1004,9 +1011,15 @@ def rotate( def affine( - img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], - interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, - resample: Optional[int] = None, fillcolor: Optional[List[float]] = None + img: Tensor, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. If the image is torch Tensor, it is expected @@ -1053,9 +1066,7 @@ def affine( interpolation = _interpolation_modes_from_int(interpolation) if fillcolor is not None: - warnings.warn( - "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" - ) + warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead") fill = fillcolor if not isinstance(angle, (int, float)): @@ -1156,7 +1167,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: - """ Erase the input Tensor Image with given value. + """Erase the input Tensor Image with given value. This transform does not support PIL Image. Args: @@ -1172,12 +1183,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool Tensor Image: Erased image. """ if not isinstance(img, torch.Tensor): - raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + raise TypeError("img should be Tensor Image. Got {}".format(type(img))) if not inplace: img = img.clone() - img[..., i:i + h, j:j + w] = v + img[..., i : i + h, j : j + w] = v return img @@ -1208,34 +1219,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa PIL Image or Tensor: Gaussian Blurred version of the image. """ if not isinstance(kernel_size, (int, list, tuple)): - raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size))) + raise TypeError("kernel_size should be int or a sequence of integers. 
Got {}".format(type(kernel_size))) if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] if len(kernel_size) != 2: - raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size))) + raise ValueError("If kernel_size is a sequence its length should be 2. Got {}".format(len(kernel_size))) for ksize in kernel_size: if ksize % 2 == 0 or ksize < 0: - raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size)) + raise ValueError("kernel_size should have odd and positive integers. Got {}".format(kernel_size)) if sigma is None: sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): - raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma))) + raise TypeError("sigma should be either float or sequence of floats. Got {}".format(type(sigma))) if isinstance(sigma, (int, float)): sigma = [float(sigma), float(sigma)] if isinstance(sigma, (list, tuple)) and len(sigma) == 1: sigma = [sigma[0], sigma[0]] if len(sigma) != 2: - raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma))) + raise ValueError("If sigma is a sequence, its length should be 2. Got {}".format(len(sigma))) for s in sigma: - if s <= 0.: - raise ValueError('sigma should have positive values. Got {}'.format(sigma)) + if s <= 0.0: + raise ValueError("sigma should have positive values. Got {}".format(sigma)) t_img = img if not isinstance(img, torch.Tensor): if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image or Tensor. Got {}".format(type(img))) t_img = to_tensor(img) @@ -1278,7 +1289,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: PIL Image or Tensor: Posterized image. """ if not (0 <= bits <= 8): - raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits)) + raise ValueError("The number if bits should be between 0 and 8. Got {}".format(bits)) if not isinstance(img, torch.Tensor): return F_pil.posterize(img, bits) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 3829637fdb7..d987d5a0ef7 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -2,8 +2,9 @@ from typing import Any, List, Sequence import numpy as np +from PIL import Image, ImageEnhance, ImageOps + import torch -from PIL import Image, ImageOps, ImageEnhance try: import accimage @@ -29,14 +30,14 @@ def _get_image_size(img: Any) -> List[int]: @torch.jit.unused def _get_image_num_channels(img: Any) -> int: if _is_pil_image(img): - return 1 if img.mode == 'L' else 3 + return 1 if img.mode == "L" else 3 raise TypeError("Unexpected type {}".format(type(img))) @torch.jit.unused def hflip(img): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return img.transpose(Image.FLIP_LEFT_RIGHT) @@ -44,7 +45,7 @@ def hflip(img): @torch.jit.unused def vflip(img): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) return img.transpose(Image.FLIP_TOP_BOTTOM) @@ -52,7 +53,7 @@ def vflip(img): @torch.jit.unused def adjust_brightness(img, brightness_factor): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Brightness(img) img = enhancer.enhance(brightness_factor) @@ -62,7 +63,7 @@ def adjust_brightness(img, brightness_factor): @torch.jit.unused def adjust_contrast(img, contrast_factor): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Contrast(img) img = enhancer.enhance(contrast_factor) @@ -72,7 +73,7 @@ def adjust_contrast(img, contrast_factor): @torch.jit.unused def adjust_saturation(img, saturation_factor): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Color(img) img = enhancer.enhance(saturation_factor) @@ -81,39 +82,39 @@ def adjust_saturation(img, saturation_factor): @torch.jit.unused def adjust_hue(img, hue_factor): - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)) if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: + if input_mode in {"L", "1", "I", "F"}: return img - h, s, v = img.convert('HSV').split() + h, s, v = img.convert("HSV").split() np_h = np.array(h, dtype=np.uint8) # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): + with np.errstate(over="ignore"): np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') + h = Image.fromarray(np_h, "L") - img = Image.merge('HSV', (h, s, v)).convert(input_mode) + img = Image.merge("HSV", (h, s, v)).convert(input_mode) return img @torch.jit.unused def adjust_gamma(img, gamma, gain=1): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') + raise ValueError("Gamma should be a non-negative real number") input_mode = img.mode - img = img.convert('RGB') - gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.convert("RGB") + gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma) for ele in range(256)] * 3 img = img.point(gamma_map) # use PIL's point-function to accelerate this part img = img.convert(input_mode) @@ -136,8 +137,9 @@ def pad(img, padding, fill=0, padding_mode="constant"): padding = tuple(padding) if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) if isinstance(padding, tuple) and len(padding) == 1: # Compatibility with `functional_tensor.pad` @@ -176,7 +178,7 @@ def pad(img, padding, fill=0, padding_mode="constant"): pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0) - if img.mode == 'P': + if img.mode == "P": palette = img.getpalette() img = np.asarray(img) img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) @@ -198,7 +200,7 @@ def pad(img, padding, fill=0, padding_mode="constant"): @torch.jit.unused def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return img.crop((left, top, left + width, top + height)) @@ -206,9 +208,9 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag @torch.jit.unused def resize(img, size, interpolation=Image.BILINEAR, max_size=None): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): - raise TypeError('Got inappropriate size arg: {}'.format(size)) + raise TypeError("Got inappropriate size arg: {}".format(size)) if isinstance(size, Sequence) and len(size) == 1: size = size[0] @@ -251,8 +253,7 @@ def _parse_fill(fill, img, name="fillcolor"): fill = tuple([fill] * num_bands) if isinstance(fill, (list, tuple)): if len(fill) != num_bands: - msg = ("The number of elements in 'fill' does not match the number of " - "bands of the image ({} != {})") + msg = "The number of elements in 'fill' does not match the number of " "bands of the image ({} != {})" raise ValueError(msg.format(len(fill), num_bands)) fill = tuple(fill) @@ -263,7 +264,7 @@ def _parse_fill(fill, img, name="fillcolor"): @torch.jit.unused def affine(img, matrix, interpolation=0, fill=None): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) output_size = img.size opts = _parse_fill(fill, img) @@ -282,7 +283,7 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None): @torch.jit.unused def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) opts = _parse_fill(fill, img) @@ -292,17 +293,17 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None) @torch.jit.unused def to_grayscale(img, num_output_channels): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) if num_output_channels == 1: - img = img.convert('L') + img = img.convert("L") elif num_output_channels == 3: - img = img.convert('L') + img = img.convert("L") np_img = np.array(img, dtype=np.uint8) np_img = np.dstack([np_img, np_img, np_img]) - img = Image.fromarray(np_img, 'RGB') + img = Image.fromarray(np_img, "RGB") else: - raise ValueError('num_output_channels should be either 1 or 3') + raise ValueError("num_output_channels should be either 1 or 3") return img @@ -310,28 +311,28 @@ def to_grayscale(img, num_output_channels): @torch.jit.unused def invert(img): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.invert(img) @torch.jit.unused def posterize(img, bits): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.posterize(img, bits) @torch.jit.unused def solarize(img, threshold): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.solarize(img, threshold) @torch.jit.unused def adjust_sharpness(img, sharpness_factor): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Sharpness(img) img = enhancer.enhance(sharpness_factor) @@ -341,12 +342,12 @@ def adjust_sharpness(img, sharpness_factor): @torch.jit.unused def autocontrast(img): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.autocontrast(img) @torch.jit.unused def equalize(img): if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) return ImageOps.equalize(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a0e32d4237e..6de01c99beb 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,10 +1,10 @@ import warnings +from typing import List, Optional, Tuple import torch from torch import Tensor -from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad from torch.jit.annotations import BroadcastingList2 -from typing import Optional, Tuple, List +from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad def _is_tensor_a_torch_image(x: Tensor) -> bool: @@ -97,7 +97,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - # factor should be forced to int for torch jit script # otherwise factor is a float and image // factor can produce different results factor = int((input_max + 1) // (output_max + 1)) - image = torch.div(image, factor, rounding_mode='floor') + image = torch.div(image, factor, rounding_mode="floor") return image.to(dtype) else: # factor should be forced to int for torch jit script @@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: if left < 0 or top < 0 or right > w or bottom > h: padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] - return pad(img[..., max(top, 0):bottom, max(left, 0):right], padding_ltrb, fill=0) + return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) return img[..., top:bottom, left:right] @@ -138,7 +138,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: _assert_channels(img, [3]) if num_output_channels not in (1, 3): - raise ValueError('num_output_channels should be either 1 or 3') + raise ValueError("num_output_channels should be either 1 or 3") r, g, b = img.unbind(dim=-3) # This implementation closely follows the TF one: @@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: if brightness_factor < 0: - raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor)) + raise ValueError("brightness_factor ({}) is not non-negative.".format(brightness_factor)) _assert_image_tensor(img) @@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: if contrast_factor < 0: - raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor)) + raise ValueError("contrast_factor ({}) is not non-negative.".format(contrast_factor)) _assert_image_tensor(img) @@ -179,10 +179,10 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: if not (-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)) if not (isinstance(img, torch.Tensor)): - raise TypeError('Input img should be Tensor image') + raise TypeError("Input img should be Tensor image") _assert_image_tensor(img) @@ -208,7 +208,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: if saturation_factor < 0: - raise ValueError('saturation_factor ({}) is 
not non-negative.'.format(saturation_factor)) + raise ValueError("saturation_factor ({}) is not non-negative.".format(saturation_factor)) _assert_image_tensor(img) @@ -219,12 +219,12 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: if not isinstance(img, torch.Tensor): - raise TypeError('Input img should be a Tensor.') + raise TypeError("Input img should be a Tensor.") _assert_channels(img, [1, 3]) if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') + raise ValueError("Gamma should be a non-negative real number") result = img dtype = img.dtype @@ -238,11 +238,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.center_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.center_crop`` instead." ) _assert_image_tensor(img) @@ -262,11 +260,9 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.five_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.five_crop`` instead." ) _assert_image_tensor(img) @@ -289,11 +285,9 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.ten_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.ten_crop`` instead." 
) _assert_image_tensor(img) @@ -351,7 +345,7 @@ def _rgb2hsv(img): hr = (maxc == r) * (bc - gc) hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) - h = (hr + hg + hb) + h = hr + hg + hb h = torch.fmod((h / 6.0 + 1.0), 1.0) return torch.stack((h, s, maxc), dim=-3) @@ -383,7 +377,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: # crop if needed if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding] - img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:img.shape[-1] - crop_right] + img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] padding = [max(x, 0) for x in padding] in_sizes = img.size() @@ -421,8 +415,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con padding = list(padding) if isinstance(padding, list) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") @@ -482,7 +477,7 @@ def resize( size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None, - antialias: Optional[bool] = None + antialias: Optional[bool] = None, ) -> Tensor: _assert_image_tensor(img) @@ -499,8 +494,9 @@ def resize( if isinstance(size, list): if len(size) not in [1, 2]: - raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a " - "{} element tuple/list".format(len(size))) + raise ValueError( + "Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size)) + ) if max_size is not None and len(size) != 1: raise ValueError( "max_size should only be passed if size specifies the length of the smaller edge, " @@ -560,12 +556,12 @@ def resize( def _assert_grid_transform_inputs( - img: Tensor, - matrix: Optional[List[float]], - interpolation: str, - fill: Optional[List[float]], - supported_interpolation_modes: List[str], - coeffs: Optional[List[float]] = None, + img: Tensor, + matrix: Optional[List[float]], + interpolation: str, + fill: Optional[List[float]], + supported_interpolation_modes: List[str], + coeffs: Optional[List[float]] = None, ): if not (isinstance(img, torch.Tensor)): @@ -588,8 +584,10 @@ def _assert_grid_transform_inputs( # Check fill num_channels = _get_image_num_channels(img) if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): - msg = ("The number of elements in 'fill' cannot broadcast to match the number of " - "channels of the image ({} != {})") + msg = ( + "The number of elements in 'fill' cannot broadcast to match the number of " + "channels of the image ({} != {})" + ) raise ValueError(msg.format(len(fill), num_channels)) if interpolation not in supported_interpolation_modes: @@ -627,7 +625,12 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ]) + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + grid.dtype, + ], + 
) if img.shape[0] > 1: # Apply same grid to a batch of images @@ -647,7 +650,7 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L mask = mask.expand_as(img) len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) - if mode == 'nearest': + if mode == "nearest": mask = mask < 0.5 img[mask] = fill_img[mask] else: # 'bilinear' @@ -658,7 +661,11 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L def _gen_affine_grid( - theta: Tensor, w: int, h: int, ow: int, oh: int, + theta: Tensor, + w: int, + h: int, + ow: int, + oh: int, ) -> Tensor: # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ # AffineGridGenerator.cpp#L18 @@ -680,7 +687,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) @@ -698,12 +705,14 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. - pts = torch.tensor([ - [-0.5 * w, -0.5 * h, 1.0], - [-0.5 * w, 0.5 * h, 1.0], - [0.5 * w, 0.5 * h, 1.0], - [0.5 * w, -0.5 * h, 1.0], - ]) + pts = torch.tensor( + [ + [-0.5 * w, -0.5 * h, 1.0], + [-0.5 * w, 0.5 * h, 1.0], + [0.5 * w, 0.5 * h, 1.0], + [0.5 * w, -0.5 * h, 1.0], + ] + ) theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) @@ -718,8 +727,11 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( - img: Tensor, matrix: List[float], interpolation: str = "nearest", - expand: bool = False, fill: Optional[List[float]] = None + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + expand: bool = False, + fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2] @@ -740,14 +752,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) # - theta1 = torch.tensor([[ - [coeffs[0], coeffs[1], coeffs[2]], - [coeffs[3], coeffs[4], coeffs[5]] - ]], dtype=dtype, device=device) - theta2 = torch.tensor([[ - [coeffs[6], coeffs[7], 1.0], - [coeffs[6], coeffs[7], 1.0] - ]], dtype=dtype, device=device) + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) @@ -769,7 +777,7 @@ def perspective( img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None ) -> Tensor: if not (isinstance(img, torch.Tensor)): - raise TypeError('Input img 
should be Tensor.') + raise TypeError("Input img should be Tensor.") _assert_image_tensor(img) @@ -779,7 +787,7 @@ def perspective( interpolation=interpolation, fill=fill, supported_interpolation_modes=["nearest", "bilinear"], - coeffs=perspective_coeffs + coeffs=perspective_coeffs, ) ow, oh = img.shape[-1], img.shape[-2] @@ -799,7 +807,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: def _get_gaussian_kernel2d( - kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device ) -> Tensor: kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) @@ -809,7 +817,7 @@ def _get_gaussian_kernel2d( def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: if not (isinstance(img, torch.Tensor)): - raise TypeError('img should be Tensor. Got {}'.format(type(img))) + raise TypeError("img should be Tensor. Got {}".format(type(img))) _assert_image_tensor(img) @@ -817,7 +825,12 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + kernel.dtype, + ], + ) # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] @@ -851,7 +864,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) _assert_channels(img, [1, 3]) - mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) + mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) return img & mask @@ -876,7 +889,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: kernel /= kernel.sum() kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + kernel.dtype, + ], + ) result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) @@ -888,7 +906,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: if sharpness_factor < 0: - raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor)) + raise ValueError("sharpness_factor ({}) is not non-negative.".format(sharpness_factor)) _assert_image_tensor(img) @@ -933,13 +951,11 @@ def _scale_channel(img_chan): hist = torch.bincount(img_chan.view(-1), minlength=256) nonzero_hist = hist[hist != 0] - step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor') + step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") if step == 0: return img_chan - lut = torch.div( - torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), - step, rounding_mode='floor') + lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") lut = torch.nn.functional.pad(lut, [1, 
0])[:-1].clamp(0, 255) return lut[img_chan.to(torch.int64)].to(torch.uint8) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 954d5f5f064..4ab83a69e92 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,26 +3,59 @@ import random import warnings from collections.abc import Sequence -from typing import Tuple, List, Optional +from typing import List, Optional, Tuple import torch from torch import Tensor +from . import functional as F +from .functional import InterpolationMode, _interpolation_modes_from_int + try: import accimage except ImportError: accimage = None -from . import functional as F -from .functional import InterpolationMode, _interpolation_modes_from_int - -__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", - "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", - "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", - "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] +__all__ = [ + "Compose", + "ToTensor", + "PILToTensor", + "ConvertImageDtype", + "ToPILImage", + "Normalize", + "Resize", + "Scale", + "CenterCrop", + "Pad", + "Lambda", + "RandomApply", + "RandomChoice", + "RandomOrder", + "RandomCrop", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomResizedCrop", + "RandomSizedCrop", + "FiveCrop", + "TenCrop", + "LinearTransformation", + "ColorJitter", + "RandomRotation", + "RandomAffine", + "Grayscale", + "RandomGrayscale", + "RandomPerspective", + "RandomErasing", + "GaussianBlur", + "InterpolationMode", + "RandomInvert", + "RandomPosterize", + "RandomSolarize", + "RandomAdjustSharpness", + "RandomAutocontrast", + "RandomEqualize", +] class Compose: @@ -61,11 +94,11 @@ def __call__(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string @@ -97,7 +130,7 @@ def __call__(self, pic): return F.to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class PILToTensor: @@ -117,7 +150,7 @@ def __call__(self, pic): return F.pil_to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class ConvertImageDtype(torch.nn.Module): @@ -164,6 +197,7 @@ class ToPILImage: .. 
_PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ + def __init__(self, mode=None): self.mode = mode @@ -179,10 +213,10 @@ def __call__(self, pic): return F.to_pil_image(pic, self.mode) def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" if self.mode is not None: - format_string += 'mode={0}'.format(self.mode) - format_string += ')' + format_string += "mode={0}".format(self.mode) + format_string += ")" return format_string @@ -221,7 +255,7 @@ def forward(self, tensor: Tensor) -> Tensor: return F.normalize(tensor, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std) class Resize(torch.nn.Module): @@ -300,17 +334,20 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( - self.size, interpolate_str, self.max_size, self.antialias) + return self.__class__.__name__ + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format( + self.size, interpolate_str, self.max_size, self.antialias + ) class Scale(Resize): """ Note: This transform is deprecated in favor of Resize. """ + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") + warnings.warn( + "The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead." + ) super(Scale, self).__init__(*args, **kwargs) @@ -341,7 +378,7 @@ def forward(self, img): return F.center_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class Pad(torch.nn.Module): @@ -394,8 +431,9 @@ def __init__(self, padding, fill=0, padding_mode="constant"): raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) self.padding = padding self.fill = fill @@ -412,8 +450,9 @@ def forward(self, img): return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ - format(self.padding, self.fill, self.padding_mode) + return self.__class__.__name__ + "(padding={0}, fill={1}, padding_mode={2})".format( + self.padding, self.fill, self.padding_mode + ) class Lambda: @@ -432,7 +471,7 @@ def __call__(self, img): return self.lambd(img) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class RandomTransforms: @@ -451,11 +490,11 @@ def __call__(self, *args, **kwargs): raise NotImplementedError() def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return 
format_string @@ -492,18 +531,18 @@ def forward(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += '\n p={}'.format(self.p) + format_string = self.__class__.__name__ + "(" + format_string += "\n p={}".format(self.p) for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript. - """ + """Apply a list of transformations in a random order. This transform does not support torchscript.""" + def __call__(self, img): order = list(range(len(self.transforms))) random.shuffle(order) @@ -513,8 +552,8 @@ def __call__(self, img): class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript. - """ + """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" + def __call__(self, img): t = random.choice(self.transforms) return t(img) @@ -579,23 +618,19 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int th, tw = output_size if h + 1 < th or w + 1 < tw: - raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) - ) + raise ValueError("Required crop size {} is larger then input image size {}".format((th, tw), (h, w))) if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): super().__init__() - self.size = tuple(_setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." 
- )) + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) self.padding = padding self.pad_if_needed = pad_if_needed @@ -658,7 +693,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomVerticalFlip(torch.nn.Module): @@ -688,7 +723,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomPerspective(torch.nn.Module): @@ -768,27 +803,27 @@ def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[L half_height = height // 2 half_width = width // 2 topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), ] topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), ] botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), ] botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] return startpoints, endpoints def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomResizedCrop(torch.nn.Module): @@ -820,7 +855,7 @@ class RandomResizedCrop(torch.nn.Module): """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -844,9 +879,7 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat self.ratio = ratio @staticmethod - def get_params( - img: Tensor, scale: List[float], ratio: List[float] - ) -> Tuple[int, int, int, int]: + def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. 
Args: @@ -864,9 +897,7 @@ def get_params( log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) @@ -904,10 +935,10 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) - format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) - format_string += ', interpolation={0})'.format(interpolate_str) + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) return format_string @@ -915,9 +946,12 @@ class RandomSizedCrop(RandomResizedCrop): """ Note: This transform is deprecated in favor of RandomResizedCrop. """ + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + - "please use transforms.RandomResizedCrop instead.") + warnings.warn( + "The use of the transforms.RandomSizedCrop transform is deprecated, " + + "please use transforms.RandomResizedCrop instead." + ) super(RandomSizedCrop, self).__init__(*args, **kwargs) @@ -964,7 +998,7 @@ def forward(self, img): return F.five_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class TenCrop(torch.nn.Module): @@ -1013,7 +1047,7 @@ def forward(self, img): return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): - return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format(self.size, self.vertical_flip) class LinearTransformation(torch.nn.Module): @@ -1038,17 +1072,25 @@ class LinearTransformation(torch.nn.Module): def __init__(self, transformation_matrix, mean_vector): super().__init__() if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + raise ValueError( + "transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()) + ) if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + - " as any one of the dimensions of the transformation_matrix [{}]" - .format(tuple(transformation_matrix.size()))) + raise ValueError( + "mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]".format( + tuple(transformation_matrix.size()) + ) + ) if transformation_matrix.device != mean_vector.device: - raise ValueError("Input tensors should be on the same device. 
Got {} and {}" - .format(transformation_matrix.device, mean_vector.device)) + raise ValueError( + "Input tensors should be on the same device. Got {} and {}".format( + transformation_matrix.device, mean_vector.device + ) + ) self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1064,13 +1106,17 @@ def forward(self, tensor: Tensor) -> Tensor: shape = tensor.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0]) + ) if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + "Got {} vs {}".format(tensor.device, self.mean_vector.device) + ) flat_tensor = tensor.view(-1, n) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) @@ -1078,9 +1124,9 @@ def forward(self, tensor: Tensor) -> Tensor: return tensor def __repr__(self): - format_string = self.__class__.__name__ + '(transformation_matrix=' - format_string += (str(self.transformation_matrix.tolist()) + ')') - format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + format_string = self.__class__.__name__ + "(transformation_matrix=" + format_string += str(self.transformation_matrix.tolist()) + ")" + format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")" return format_string @@ -1107,14 +1153,13 @@ class ColorJitter(torch.nn.Module): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() - self.brightness = self._check_input(brightness, 'brightness') - self.contrast = self._check_input(contrast, 'contrast') - self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: raise ValueError("If {} is a single number, it must be non negative.".format(name)) @@ -1134,11 +1179,12 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs return value @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + def get_params( + brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: 
Optional[List[float]], + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1172,8 +1218,9 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1188,11 +1235,11 @@ def forward(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += 'brightness={0}'.format(self.brightness) - format_string += ', contrast={0}'.format(self.contrast) - format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) + format_string = self.__class__.__name__ + "(" + format_string += "brightness={0}".format(self.brightness) + format_string += ", contrast={0}".format(self.contrast) + format_string += ", saturation={0}".format(self.saturation) + format_string += ", hue={0})".format(self.hue) return format_string @@ -1242,10 +1289,10 @@ def __init__( ) interpolation = _interpolation_modes_from_int(interpolation) - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if center is not None: - _check_sequence_input(center, "center", req_sizes=(2, )) + _check_sequence_input(center, "center", req_sizes=(2,)) self.center = center @@ -1289,14 +1336,14 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) - format_string += ', interpolation={0}'.format(interpolate_str) - format_string += ', expand={0}'.format(self.expand) + format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees) + format_string += ", interpolation={0}".format(interpolate_str) + format_string += ", expand={0}".format(self.expand) if self.center is not None: - format_string += ', center={0}'.format(self.center) + format_string += ", center={0}".format(self.center) if self.fill is not None: - format_string += ', fill={0}'.format(self.fill) - format_string += ')' + format_string += ", fill={0}".format(self.fill) + format_string += ")" return format_string @@ -1337,8 +1384,15 @@ class RandomAffine(torch.nn.Module): """ def __init__( - self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, - fillcolor=None, resample=None + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + fillcolor=None, + resample=None, ): super().__init__() if resample is not None: @@ -1361,17 +1415,17 @@ def __init__( ) fill = fillcolor - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: - _check_sequence_input(translate, "translate", req_sizes=(2, )) + _check_sequence_input(translate, "translate", req_sizes=(2,)) for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") self.translate = translate if scale is not None: - _check_sequence_input(scale, "scale", 
req_sizes=(2, )) + _check_sequence_input(scale, "scale", req_sizes=(2,)) for s in scale: if s <= 0: raise ValueError("scale values should be positive") @@ -1393,11 +1447,11 @@ def __init__( @staticmethod def get_params( - degrees: List[float], - translate: Optional[List[float]], - scale_ranges: Optional[List[float]], - shears: Optional[List[float]], - img_size: List[int] + degrees: List[float], + translate: Optional[List[float]], + scale_ranges: Optional[List[float]], + shears: Optional[List[float]], + img_size: List[int], ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: """Get parameters for affine transformation @@ -1450,20 +1504,20 @@ def forward(self, img): return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) def __repr__(self): - s = '{name}(degrees={degrees}' + s = "{name}(degrees={degrees}" if self.translate is not None: - s += ', translate={translate}' + s += ", translate={translate}" if self.scale is not None: - s += ', scale={scale}' + s += ", scale={scale}" if self.shear is not None: - s += ', shear={shear}' + s += ", shear={shear}" if self.interpolation != InterpolationMode.NEAREST: - s += ', interpolation={interpolation}' + s += ", interpolation={interpolation}" if self.fill != 0: - s += ', fill={fill}' - s += ')' + s += ", fill={fill}" + s += ")" d = dict(self.__dict__) - d['interpolation'] = self.interpolation.value + d["interpolation"] = self.interpolation.value return s.format(name=self.__class__.__name__, **d) @@ -1498,7 +1552,7 @@ def forward(self, img): return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) def __repr__(self): - return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) + return self.__class__.__name__ + "(num_output_channels={0})".format(self.num_output_channels) class RandomGrayscale(torch.nn.Module): @@ -1535,11 +1589,11 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={0})'.format(self.p) + return self.__class__.__name__ + "(p={0})".format(self.p) class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. + """Randomly selects a rectangle region in an torch Tensor image and erases its pixels. This transform does not support PIL Image. 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 @@ -1590,7 +1644,7 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace @staticmethod def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None ) -> Tuple[int, int, int, int, Tensor]: """Get parameters for ``erase`` for a random erasing. 
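The RandomErasing.get_params hunk just below (like RandomResizedCrop.get_params earlier in this file) draws the aspect ratio log-uniformly, so a ratio r and its reciprocal 1/r are equally likely. A minimal standalone sketch of that sampling step, shown only for illustration — the variable names mirror the surrounding code, but the snippet is not part of the patch:

import torch

ratio = (0.3, 3.3)  # RandomErasing's default ratio range
log_ratio = torch.log(torch.tensor(ratio))
# uniform in log-space, then exponentiate back: symmetric around an aspect ratio of 1.0
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()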
@@ -1611,9 +1665,7 @@ def get_params( log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() h = int(round(math.sqrt(erase_area * aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio))) @@ -1625,8 +1677,8 @@ def get_params( else: v = torch.tensor(value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1, )).item() - j = torch.randint(0, img_w - w + 1, size=(1, )).item() + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() return i, j, h, w, v # Return original image @@ -1644,7 +1696,9 @@ def forward(self, img): # cast self.value to script acceptable type if isinstance(self.value, (int, float)): - value = [self.value, ] + value = [ + self.value, + ] elif isinstance(self.value, str): value = None elif isinstance(self.value, tuple): @@ -1663,11 +1717,11 @@ def forward(self, img): return img def __repr__(self): - s = '(p={}, '.format(self.p) - s += 'scale={}, '.format(self.scale) - s += 'ratio={}, '.format(self.ratio) - s += 'value={}, '.format(self.value) - s += 'inplace={})'.format(self.inplace) + s = "(p={}, ".format(self.p) + s += "scale={}, ".format(self.scale) + s += "ratio={}, ".format(self.ratio) + s += "value={}, ".format(self.value) + s += "inplace={})".format(self.inplace) return self.__class__.__name__ + s @@ -1700,7 +1754,7 @@ def __init__(self, kernel_size, sigma=(0.1, 2.0)): raise ValueError("If sigma is a single number, it must be positive.") sigma = (sigma, sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. 
< sigma[0] <= sigma[1]: + if not 0.0 < sigma[0] <= sigma[1]: raise ValueError("sigma values should be positive and of the form (min, max).") else: raise ValueError("sigma should be a single number or a list/tuple with length 2.") @@ -1732,8 +1786,8 @@ def forward(self, img: Tensor) -> Tensor: return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) def __repr__(self): - s = '(kernel_size={}, '.format(self.kernel_size) - s += 'sigma={})'.format(self.sigma) + s = "(kernel_size={}, ".format(self.kernel_size) + s += "sigma={})".format(self.sigma) return self.__class__.__name__ + s @@ -1758,7 +1812,7 @@ def _check_sequence_input(x, name, req_sizes): raise ValueError("{} should be sequence of length {}.".format(name, msg)) -def _setup_angle(x, name, req_sizes=(2, )): +def _setup_angle(x, name, req_sizes=(2,)): if isinstance(x, numbers.Number): if x < 0: raise ValueError("If {} is a single number, it must be positive.".format(name)) @@ -1796,7 +1850,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomPosterize(torch.nn.Module): @@ -1828,7 +1882,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p) class RandomSolarize(torch.nn.Module): @@ -1860,7 +1914,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + return self.__class__.__name__ + "(threshold={},p={})".format(self.threshold, self.p) class RandomAdjustSharpness(torch.nn.Module): @@ -1892,7 +1946,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + return self.__class__.__name__ + "(sharpness_factor={},p={})".format(self.sharpness_factor, self.p) class RandomAutocontrast(torch.nn.Module): @@ -1922,7 +1976,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomEqualize(torch.nn.Module): @@ -1952,4 +2006,4 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) diff --git a/torchvision/utils.py b/torchvision/utils.py index 494661e6ad8..ca8ad0c1073 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -1,10 +1,12 @@ -from typing import Union, Optional, List, Tuple, Text, BinaryIO -import pathlib -import torch import math +import pathlib import warnings +from typing import BinaryIO, List, Optional, Text, Tuple, Union + import numpy as np -from PIL import Image, ImageDraw, ImageFont, ImageColor +from PIL import Image, ImageColor, ImageDraw, ImageFont + +import torch __all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] @@ -18,7 +20,7 @@ def make_grid( value_range: Optional[Tuple[int, int]] = None, scale_each: bool = False, pad_value: int = 0, - **kwargs + **kwargs, ) -> torch.Tensor: """ Make a grid of images. @@ -41,9 +43,8 @@ def make_grid( Returns: grid (Tensor): the tensor containing grid of images. 
""" - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") if "range" in kwargs.keys(): warning = "range will be deprecated, please use value_range instead." @@ -67,8 +68,9 @@ def make_grid( if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place if value_range is not None: - assert isinstance(value_range, tuple), \ - "value_range has to be a tuple (min, max) if specified. min and max are numbers" + assert isinstance( + value_range, tuple + ), "value_range has to be a tuple (min, max) if specified. min and max are numbers" def norm_ip(img, low, high): img.clamp_(min=low, max=high) @@ -115,7 +117,7 @@ def save_image( tensor: Union[torch.Tensor, List[torch.Tensor]], fp: Union[Text, pathlib.Path, BinaryIO], format: Optional[str] = None, - **kwargs + **kwargs, ) -> None: """ Save a given Tensor into an image file. @@ -131,7 +133,7 @@ def save_image( grid = make_grid(tensor, **kwargs) # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer - ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() im = Image.fromarray(ndarr) im.save(fp, format=format) @@ -145,7 +147,7 @@ def draw_bounding_boxes( fill: Optional[bool] = False, width: int = 1, font: Optional[str] = None, - font_size: int = 10 + font_size: int = 10, ) -> torch.Tensor: """