diff --git a/.github/actions/run_performance_tests/action.yml b/.github/actions/run_performance_tests/action.yml deleted file mode 100644 index 840d03b6..00000000 --- a/.github/actions/run_performance_tests/action.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: run_performance_tests -description: MeshPy performance test suite -runs: - using: composite - steps: - - name: MeshPy testing - shell: bash - env: - CUBIT_ROOT: /imcs/public/compsim/opt/cubit-15.2 - PERFORMANCE_TESTING_HOST: github-sisyphos-docker - run: | - cd ${GITHUB_WORKSPACE} - source python-testing-environment/bin/activate - pip install .[CI-CD] - python --version - pip list - cd tests - python performance_testing.py diff --git a/.github/actions/run_tests/action.yml b/.github/actions/run_tests/action.yml index f6556b5d..7d452629 100644 --- a/.github/actions/run_tests/action.yml +++ b/.github/actions/run_tests/action.yml @@ -9,22 +9,10 @@ inputs: description: Command to source the virtual environment required: false default: "" - require_four_c: - description: Fail if the 4C tests can not be performed + additional-pytest-flags: + description: Additional flags to pass to pytest, i.e., markers required: false - default: 1 - require_arborx: - description: Fail if the ArborX tests can not be performed - required: false - default: 1 - require_cubitpy: - description: Fail if the CubitPy tests can not be performed - required: false - default: 1 - coverage_config: - description: Config file to use for coverage analysis - required: false - default: "coverage.config" + default: "" runs: using: composite steps: @@ -32,22 +20,16 @@ runs: shell: bash env: MESHPY_FOUR_C_EXE: /home/user/4C/build/4C - TESTING_GITHUB: 1 - TESTING_GITHUB_4C: ${{ inputs.require_four_c }} - TESTING_GITHUB_ARBORX: ${{ inputs.require_arborx }} - TESTING_GITHUB_CUBITPY: ${{ inputs.require_cubitpy }} CUBIT_ROOT: /imcs/public/compsim/opt/cubit-15.2 OMPI_MCA_rmaps_base_oversubscribe: 1 + PERFORMANCE_TESTING_HOST: github-sisyphos-docker run: | cd ${GITHUB_WORKSPACE} ${{ inputs.source-command }} pip install ${{ inputs.install-command }} python --version pip list - cd tests - coverage run --rcfile=${{ inputs.coverage_config }} testing_main.py - coverage html - coverage report - coverage-badge -o htmlcov/coverage.svg - coverage run --rcfile=${{ inputs.coverage_config }} -m pytest pytest_testing_cosserat_curve.py - coverage report + TEMP_DIR="${RUNNER_TEMP}/meshpy_pytest" + mkdir -p "$TEMP_DIR" + echo "PYTEST_TMPDIR=$TEMP_DIR" >> $GITHUB_ENV + pytest --basetemp="$TEMP_DIR" ${{ inputs.additional-pytest-flags}} diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2b4b45df..0dd766e6 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -47,16 +47,12 @@ jobs: python-version: ${{ matrix.python-version }} - name: Run the test suite uses: ./.github/actions/run_tests - with: - require_four_c: 0 - require_arborx: 0 - require_cubitpy: 0 - name: Upload test results on failure if: failure() uses: actions/upload-artifact@v4 with: name: ${{github.job}}-${{ matrix.os-version }}-python${{ matrix.python-version }}-${{github.run_number}} - path: ${{github.workspace}}/tests/testing-tmp/ + path: ${{ env.PYTEST_TMPDIR }} meshpy-testing-cubitpy: name: self-hosted with CubitPy @@ -75,14 +71,13 @@ jobs: uses: ./.github/actions/run_tests with: source-command: "source python-testing-environment/bin/activate" - require_four_c: 0 - require_arborx: 0 + additional-pytest-flags: "--CubitPy" - name: Upload test results on failure if: failure() uses: actions/upload-artifact@v4 with: name: ${{github.job}}-${{github.run_number}} - path: ${{github.workspace}}/tests/testing-tmp/ + path: ${{ env.PYTEST_TMPDIR }} meshpy-testing-4C-arborx: name: ubuntu-latest with 4C and ArborX @@ -106,14 +101,13 @@ jobs: with: source-command: "source python-testing-environment/bin/activate" install-command: "-e .[CI-CD]" - require_cubitpy: 0 - coverage_config: "coverage_local.config" + additional-pytest-flags: "--4C --ArborX" - name: Upload test results on failure if: failure() uses: actions/upload-artifact@v4 with: name: ${{github.job}}-${{github.run_number}} - path: ${{github.workspace}}/tests/testing-tmp/ + path: ${{ env.PYTEST_TMPDIR }} meshpy-performance-testing: name: performance tests @@ -129,5 +123,15 @@ jobs: uses: ./.github/actions/setup_virtual_python_environment with: python-exe: /home_local/github-runner/testing_lib/spack/opt/spack/linux-ubuntu20.04-icelake/gcc-9.4.0/python-3.12.1-qnjucxirxh534suwewl6drfa237u6t7w/bin/python - - name: Run the performance test suite - uses: ./.github/actions/run_performance_tests + - name: Run the test suite + uses: ./.github/actions/run_tests + with: + source-command: "source python-testing-environment/bin/activate" + install-command: ".[CI-CD]" + additional-pytest-flags: "--performance-tests --exclude-standard-tests" + - name: Upload test results on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: ${{github.job}}-${{github.run_number}} + path: ${{ env.PYTEST_TMPDIR }} diff --git a/README.md b/README.md index 65f64647..ad0d160a 100644 --- a/README.md +++ b/README.md @@ -86,16 +86,21 @@ tests with 4C export MESHPY_FOUR_C_EXE=path_to_4C ``` -To check if everything worked as expected, run the tests +To check if everything worked as expected, run the standard tests with ```bash -cd /tests -python testing_main.py +pytest ``` -Also run the performance tests (the reference time values and host name might have to be adapted in the file `/tests/performance_testing.py`) +Further tests can be added with the following flags: `--4C`, `--ArborX`, `--CubitPy`, `--performance-tests`. +These can be arbitrarily combined, for example ```bash -cd /tests -python performance_testing.py +pytest --4C --CubityPy +``` +executes the standard tests, the 4C tests and the CubitPy tests. Note that the reference time values for the performance tests might not suite your system. + +Finally, the base tests can be deactivated with `--exclude-standard-tests`. For example to just run the CubitPy tests execute +```bash +pytest --CubitPy --exclude-standard-tests ``` Before you are ready to contribute to MeshPy, please make sure to install the `pre-commit hook` within the python environment to follow our style guides: @@ -122,8 +127,8 @@ cd /build/geometric_search cmake ../../meshpy/geometric_search/src/ make -j4 ``` + If the ArborX extension is working correctly can be checked by running the geometric search tests ```bash -cd /tests -python testing_geometric_search.py +pytest --ArborX ``` diff --git a/pyproject.toml b/pyproject.toml index 23453b8e..0e7f2ff7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "numpy-quaternion", "pre-commit", "pytest", + "pytest-cov", "pyvista", "pyvista_utils@git+https://github.com/isteinbrecher/pyvista_utils.git@main", "scipy", @@ -55,3 +56,13 @@ CI-CD = [ "cubitpy@git+https://github.com/imcs-compsim/cubitpy.git@main", "setuptools" # Needed for coverage-badge ] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-p pytest_cov --cov-report=term --cov-report=html --cov-fail-under=0 --cov=meshpy/ --cov=tutorial/ --cov-append" +markers = [ + "fourc: tests in combination with 4C", + "arborx: tests in combination with ArborX", + "cubitpy: tests in combination with CubitPy", + "performance: performance tests of MeshPy" +] diff --git a/tests/__init__.py b/tests/__init__.py index df62647c..b709f810 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,4 +28,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ----------------------------------------------------------------------------- -"""This module defines testing functionality for MeshPy.""" +"""This module tests MeshPy.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..038d382e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- +# ----------------------------------------------------------------------------- +# MeshPy: A beam finite element input generator +# +# MIT License +# +# Copyright (c) 2018-2024 +# Ivo Steinbrecher +# Institute for Mathematics and Computer-Based Simulation +# Universitaet der Bundeswehr Muenchen +# https://www.unibw.de/imcs-en +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ----------------------------------------------------------------------------- +"""Testing framework infrastructure.""" + +import os +import shutil +import subprocess +from difflib import unified_diff +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import pytest +import vtk +from _pytest.config import Config +from _pytest.config.argparsing import Parser +from vtk_utils.compare_grids import compare_grids + +from meshpy import InputFile + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest. + + Args: + parser: Pytest parser + """ + + parser.addoption( + "--4C", + action="store_true", + default=False, + help="Execute standard and 4C based tests.", + ) + + parser.addoption( + "--ArborX", + action="store_true", + default=False, + help="Execute standard and ArborX based tests.", + ) + + parser.addoption( + "--CubitPy", + action="store_true", + default=False, + help="Execute standard and CubitPy based tests.", + ) + + parser.addoption( + "--performance-tests", + action="store_true", + default=False, + help="Execute standard and performance tests.", + ) + + parser.addoption( + "--exclude-standard-tests", + action="store_true", + default=False, + help="Exclude standard tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: list) -> None: + """Filter tests based on their markers and provided command line options. + + Currently configured options: + `pytest`: Execute standard tests with no markers + `pytest --4C`: Execute standard tests and tests with the `fourc` marker + `pytest --ArborX`: Execute standard tests and tests with the `arborx` marker + `pytest --CubitPy`: Execute standard tests and tests with the `cubitpy` marker + `pytest --performance-tests`: Execute standard tests and tests with the `performance` marker + `pytest --exclude-standard-tests`: Execute tests with any other marker and exclude the standard unmarked tests + + Args: + config: Pytest config + items: Pytest list of tests + """ + + selected_tests = [] + + # loop over all collected tests + for item in items: + # Get all set markers for current test (e.g. `fourc_arborx`, `cubitpy`, `performance`, ...) + markers = [marker.name for marker in item.iter_markers()] + + for flag, marker in zip( + ["--4C", "--ArborX", "--CubitPy", "--performance-tests"], + ["fourc", "arborx", "cubitpy", "performance"], + ): + if config.getoption(flag) and marker in markers: + selected_tests.append(item) + + if not markers and not config.getoption("--exclude-standard-tests"): + selected_tests.append(item) + + deselected_tests = list(set(items) - set(selected_tests)) + + items[:] = selected_tests + config.hook.pytest_deselected(items=deselected_tests) + + +@pytest.fixture(scope="session") +def reference_file_directory() -> Path: + """Provide the path to the reference file directory. + + Returns: + Path: A Path object representing the full path to the reference file directory. + """ + + testing_path = Path(__file__).resolve().parent + return testing_path / "reference-files" + + +@pytest.fixture(scope="function") +def current_test_name(request: pytest.FixtureRequest) -> str: + """Return the name of the current pytest test. + + Args: + request: The pytest request object. + + Returns: + str: The name of the current pytest test. + """ + + return request.node.name + + +@pytest.fixture(scope="function") +def get_string() -> Callable: + """Return function to get string from different types of input. + + Necessary to enable the function call through pytest fixtures. + + Returns: + Function to get string from file. + """ + + def _get_string( + data: Union[Path, str, InputFile], input_file_kwargs: dict = {} + ) -> str: + """Get string from file, string or InputFile. + + Args: + data: Object that should be converted to a string. + input_file_kwargs: Dictionary which contains the settings when extracting + the string from the input file. + + Returns: + String representation of data. + """ + + if isinstance(data, str): + return data + elif isinstance(data, Path): + if not data.is_file(): + raise FileNotFoundError(f"File {data} does not exist!") + with open(data, "r") as file: + return file.read() + elif isinstance(data, InputFile): + return data.get_string(**input_file_kwargs) + else: + raise TypeError(f"Data type {type(data)} not implemented.") + + return _get_string + + +@pytest.fixture(scope="function") +def get_corresponding_reference_file_path( + reference_file_directory, current_test_name +) -> Callable: + """Return function to get path to corresponding reference file for each + test. + + Necessary to enable the function call through pytest fixtures. + """ + + def _get_corresponding_reference_file_path( + reference_file_base_name: Optional[str] = None, + additional_identifier: Optional[str] = None, + extension: str = "dat", + ) -> Path: + """Get path to corresponding reference file for each test. Also check + if this file exists. Basename, additional identifier and extension can + be adjusted. + + Args: + reference_file_base_name: Basename of reference file, if none is + provided the current test name is utilized + additional_identifier: Additional identifier for reference file, by default none + extension: Extension of reference file, by default ".dat" + + Returns: + Path to reference file. + """ + + corresponding_reference_file = reference_file_base_name or current_test_name + + if additional_identifier: + corresponding_reference_file += f"_{additional_identifier}" + + corresponding_reference_file += "." + extension + + corresponding_reference_file_path = ( + reference_file_directory / corresponding_reference_file + ) + + assert os.path.isfile(corresponding_reference_file_path) + + return corresponding_reference_file_path + + return _get_corresponding_reference_file_path + + +@pytest.fixture(scope="function") +def assert_results_equal(get_string, tmp_path, current_test_name) -> Callable: + """Return function to compare either string or files. + + Necessary to enable the function call through pytest fixtures. + + Args: + get_string: Function to get string from file. + tmp_path: Temporary path for testing. + current_test_name: Name of the current test. + + Returns: + Function to compare results. + """ + + def _assert_results_equal( + reference: Union[Path, str], + result: Union[Path, str, InputFile], + rtol: Optional[float] = None, + atol: Optional[float] = None, + input_file_kwargs: dict = { + "check_nox": False, + "header": False, + }, + **kwargs, + ) -> None: + """Comparison between reference and result with relative or absolute + tolerance. + + If the comparison fails, an assertion is raised. + + Args: + reference: The reference data. + result: The result data. + rtol: The relative tolerance. + atol: The absolute tolerance. + input_file_kwargs: Dictionary which contains the settings when extracting + the string from the input file. + """ + + # Compare two universal files + if isinstance(reference, Path) and isinstance(result, Path): + if reference.suffix != result.suffix: + raise RuntimeError( + "Reference and result file must be of same file type!" + ) + elif reference.suffix in [".vtk", ".vtu"]: + compare_vtk_files(reference, result, rtol, atol) + else: + raise NotImplementedError( + f"Comparison is not yet implemented for {reference.suffix} files." + ) + + # String based comparison + else: + # retrieve strings to compare + [reference_string, result_string] = [ + get_string(data, input_file_kwargs) for data in [reference, result] + ] + + # compare strings and handle non-matching strings + try: + compare_strings(reference_string, result_string, rtol, atol, **kwargs) + except AssertionError as error: + if isinstance(reference, Path): + handle_unequal_strings( + tmp_path, current_test_name, result_string, reference + ) + raise AssertionError(str(error)) + + return _assert_results_equal + + +def compare_strings( + reference: str, + result: str, + rtol: Optional[float] = None, + atol: Optional[float] = None, + string_splitter: str = " ", +) -> None: + """Compare if two strings are identical, optionally within a given + tolerance. If the comparison fails, an error is raised. + + Args: + reference: The reference string. + result: The result string. + rtol: The relative tolerance. + atol: The absolute tolerance. + string_splitter: With which string the strings are split. + """ + + if rtol is None and atol is None: + compare_strings_equality_assert(reference, result) + else: + compare_strings_with_tolerance_assert( + reference, result, rtol, atol, string_splitter=string_splitter + ) + + +def compare_strings_equality_assert(reference: str, result: str) -> None: + """Check if two strings are exactly equal, if not raise an error. + + Args: + reference: The reference string. + result: The result string. + """ + diff = list(unified_diff(reference.splitlines(), result.splitlines(), lineterm="")) + if diff: + raise AssertionError( + "Exact string comparison failed! Difference between reference and result: \n".join( + list(diff) + ) + ) + + +def compare_strings_with_tolerance_assert( + reference: str, + result: str, + rtol: Optional[float], + atol: Optional[float], + string_splitter=" ", +) -> None: + """Compare if two strings are identical within a given tolerance. + + Args: + reference: The reference string. + result: The result string. + rtol: The relative tolerance. + atol: The absolute tolerance. + string_splitter: With which string the strings are split. + """ + + rtol = 0.0 if rtol is None else rtol + atol = 0.0 if atol is None else atol + + lines_reference = reference.strip().split("\n") + lines_result = result.strip().split("\n") + + if len(lines_reference) != len(lines_result): + raise AssertionError( + f"String comparison with tolerance failed!\n" + + f"Number of lines in reference and result differ: {len(lines_reference)} != {len(lines_result)}" + ) + + # Loop over each line in the file + for line_reference, line_result in zip(lines_reference, lines_result): + line_reference_splits = line_reference.strip().split(string_splitter) + line_result_splits = line_result.strip().split(string_splitter) + + if len(line_reference_splits) != len(line_result_splits): + raise AssertionError( + f"String comparison with tolerance failed!\n" + + f"Number of items in reference and result line differ!\n" + + f"Reference line: {line_reference}\n" + + f"Result line: {line_result}" + ) + + # Loop over each entry in the line + for item_reference, item_result in zip( + line_reference_splits, line_result_splits + ): + try: + number_reference = float(item_reference.strip()) + number_result = float(item_result.strip()) + if np.isclose(number_reference, number_result, rtol=rtol, atol=atol): + pass + else: + raise AssertionError( + f"String comparison with tolerance failed!\n" + + f"Numbers do not match within given tolerance!\n" + + f"Reference line: {line_reference}\n" + + f"Result line: {line_result}" + ) + + except ValueError: + if item_reference.strip() != item_result.strip(): + raise AssertionError( + f"String comparison with tolerance failed!\n" + + f"Strings do not match in line!\n" + + f"Reference line: {line_reference}\n" + + f"Result line: {line_result}" + ) + + +def compare_vtk_files( + reference: Path, result: Path, rtol: Optional[float], atol: Optional[float] +) -> None: + """Compare two VTK files for equality within a given tolerance. + + Args: + reference: The path to the reference VTK file. + result: The path to the result VTK file to be compared. + rtol: The relative tolerance parameter. + atol: The absolute tolerance parameter. + """ + + compare = compare_grids( + get_vtk(reference), get_vtk(result), output=True, rtol=rtol, atol=atol + ) + + if not compare[0]: + raise AssertionError("\n".join(compare[1])) + + +def get_vtk(path: Path) -> vtk.vtkDataObject: + """Return vtk data object for given vtk file. + + Args: + path: Path to .vtu/.vtk file. + + Returns: + vtk.vtkDataObject: VTK data object. + """ + + reader = vtk.vtkXMLGenericDataObjectReader() + reader.SetFileName(path) + reader.Update() + return reader.GetOutput() + + +def handle_unequal_strings( + tmp_path: Path, + current_test_name: str, + result: str, + reference_path: Path, +) -> None: + """Handle unequal string comparison. Print error message to console, write + new result file to temporary pytest directory and open VSCode diff tool if + local development is used. + + Args: + tmp_path: Temporary pytest directory + current_test_name: Name of the current test + result: "New" result string + reference_path: Path to "old" reference file + """ + + # save result string to file + result_path = tmp_path / (current_test_name + "_result.txt") + with open(result_path, "w") as file: + file.write(result) + print(f"Result string saved to: '{result_path}'.") + + # open VSCode diff tool if available + if shutil.which("code") is not None: + child = subprocess.Popen( + ["code", "--diff", result_path, reference_path], + stderr=subprocess.PIPE, + ) + child.communicate() + + +def compare_strings_with_tolerance( + reference: str, + result: str, + rtol: Optional[float], + atol: Optional[float], + string_splitter=" ", + output=False, +) -> Union[bool, Tuple[bool, str]]: + """Compare if two strings are identical within a given tolerance. + + Args: + reference: The reference string. + result: The result string. + rtol: The relative tolerance. + atol: The absolute tolerance. + string_splitter: With which string the strings are split. + output: Flag, if string containing failed comparison should be returned + + Returns: + bool: true if comparison is successful, raises AssertionError otherwise + bool, str: True if comparison is successful, False otherwise. If output + option is set, also return string containing information about failed + comparisons. + """ + + def get_return_values(flag, message): + """Get the data structure that shall be returned from this function.""" + if output: + return flag, message + else: + return flag + + try: + compare_strings_with_tolerance_assert( + reference, result, rtol, atol, string_splitter + ) + return get_return_values(True, "") + except AssertionError as error: + return get_return_values(False, str(error)) diff --git a/tests/coverage.config b/tests/coverage.config deleted file mode 100644 index e0f09835..00000000 --- a/tests/coverage.config +++ /dev/null @@ -1,6 +0,0 @@ -# Only analyze files from the repository, if MeshPy is installed in standard mode -[run] -include = - */site-packages/meshpy/* - */meshpy/tutorial/* - */meshpy/tests/* diff --git a/tests/coverage_local.config b/tests/coverage_local.config deleted file mode 100644 index 228813d4..00000000 --- a/tests/coverage_local.config +++ /dev/null @@ -1,6 +0,0 @@ -# Only analyze files from the repository, if MeshPy is installed in editable mode -[run] -include = - ../meshpy/* - ../tutorial/* - ../tests/* diff --git a/tests/performance_testing.py b/tests/performance_testing.py index d1c4a683..06e13acb 100644 --- a/tests/performance_testing.py +++ b/tests/performance_testing.py @@ -31,7 +31,6 @@ """Create a couple of different mesh cases and test the performance.""" import os -import socket import sys import time import warnings @@ -53,7 +52,7 @@ find_close_points, ) from meshpy.mesh_creation_functions.beam_basic_geometry import create_beam_mesh_line -from meshpy.utility import find_close_nodes, get_env_variable +from meshpy.utility import find_close_nodes def create_solid_block(file_path, nx, ny, nz): @@ -173,8 +172,7 @@ class TestPerformance(object): """A class to test meshpy performance.""" # Set expected test times. - expected_times = {} - expected_times["github-sisyphos-docker"] = { + expected_times = { "cubitpy_create_solid": 8.0, "meshpy_load_solid": 1.5, "meshpy_load_solid_full": 3.5, @@ -206,16 +204,10 @@ def time_function(self, name, funct, args=None, kwargs=None): kwargs = {} # Get the expected time for this function. - host = get_env_variable( - "PERFORMANCE_TESTING_HOST", default=socket.gethostname() - ) - if host in self.expected_times.keys(): - if name in self.expected_times[host].keys(): - expected_time = self.expected_times[host][name] - else: - raise ValueError("Function name {} not found!".format(name)) + if name in self.expected_times.keys(): + expected_time = self.expected_times[name] else: - raise ValueError("Host {} not found!".format(host)) + raise ValueError("Function name {} not found!".format(name)) # Time before the execution. start_time = time.time() diff --git a/tests/reference-files/test_dummy.dat b/tests/reference-files/test_dummy.dat new file mode 100644 index 00000000..e69de29b diff --git a/tests/reference-files/test_dummy_2_id.txt b/tests/reference-files/test_dummy_2_id.txt new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytest_testing_cosserat_curve.py b/tests/test_cosserat_curve.py similarity index 76% rename from tests/pytest_testing_cosserat_curve.py rename to tests/test_cosserat_curve.py index a1bb17ec..8b67856f 100644 --- a/tests/pytest_testing_cosserat_curve.py +++ b/tests/test_cosserat_curve.py @@ -32,9 +32,9 @@ import json import os +from pathlib import Path import numpy as np -import pytest import pyvista as pv import quaternion @@ -47,20 +47,12 @@ ) from meshpy.mesh_creation_functions import create_beam_mesh_helix -from .utils import ( - compare_test_result_pytest, - compare_vtk_pytest, - get_pytest_test_name, - testing_input, - testing_temp, -) - -def load_cosserat_curve_from_file(): +def load_cosserat_curve_from_file(reference_file_directory): """Load the centerline coordinates from the reference files and create the Cosserat curve.""" coordinates = np.loadtxt( - os.path.join(testing_input, "test_cosserat_curve_centerline.txt"), + os.path.join(reference_file_directory, "test_cosserat_curve_centerline.txt"), comments="#", delimiter=",", unpack=False, @@ -68,12 +60,12 @@ def load_cosserat_curve_from_file(): return CosseratCurve(coordinates) -def create_beam_solid_input_file(): +def create_beam_solid_input_file(reference_file_directory): """Create a beam and solid input file for testing purposes.""" mpy.import_mesh_full = True mesh = InputFile( - dat_file=os.path.join(testing_input, "test_cosserat_curve_mesh.dat") + dat_file=os.path.join(reference_file_directory, "test_cosserat_curve_mesh.dat") ) create_beam_mesh_helix( mesh, @@ -89,10 +81,12 @@ def create_beam_solid_input_file(): return mesh -def test_cosserat_curve_translate_and_rotate(): +def test_cosserat_curve_translate_and_rotate( + reference_file_directory, current_test_name +): """Test that a curve can be loaded, rotated and transformed.""" - curve = load_cosserat_curve_from_file() + curve = load_cosserat_curve_from_file(reference_file_directory) # Translate the curve so that the start is at the origin curve.translate(-curve.centerline_interpolation(5.0)) @@ -113,7 +107,7 @@ def test_cosserat_curve_translate_and_rotate(): def load_compare(name): """Load the compare files and return a numpy array.""" return np.loadtxt( - os.path.join(testing_input, f"{get_pytest_test_name()}_{name}.txt") + os.path.join(reference_file_directory, f"{current_test_name}_{name}.txt") ) assert np.allclose(sol_half_pos, load_compare("pos_half_ref"), rtol=1e-14) @@ -127,25 +121,32 @@ def load_compare(name): ) -def test_cosserat_curve_vtk_representation(): +def test_cosserat_curve_vtk_representation( + tmp_path, reference_file_directory, current_test_name, assert_results_equal +): """Test the vtk representation of the Cosserat curve.""" - result_name = os.path.join(testing_temp, get_pytest_test_name() + ".vtu") - curve = load_cosserat_curve_from_file() - pv.UnstructuredGrid(curve.get_pyvista_polyline()).save(result_name) - compare_vtk_pytest( - os.path.join(testing_input, get_pytest_test_name() + ".vtu"), - result_name, + reference_path = Path( + os.path.join(reference_file_directory, current_test_name + ".vtu") + ) + result_path = Path(os.path.join(tmp_path, current_test_name + ".vtu")) + + curve = load_cosserat_curve_from_file(reference_file_directory) + pv.UnstructuredGrid(curve.get_pyvista_polyline()).save(result_path) + + assert_results_equal( + reference_path, + result_path, rtol=1e-8, atol=1e-8, ) -def test_cosserat_curve_project_point(): +def test_cosserat_curve_project_point(reference_file_directory): """Test that the project point function works as expected.""" # Load the curve - curve = load_cosserat_curve_from_file() + curve = load_cosserat_curve_from_file(reference_file_directory) # Translate the curve so that the start is at the origin curve.translate(-curve.centerline_interpolation(0.0)) @@ -158,16 +159,16 @@ def test_cosserat_curve_project_point(): assert np.allclose(t_ref, curve.project_point([-5, 1, 1], t0=4.0), rtol=rtol) -def test_cosserat_mesh_transformation(): +def test_cosserat_mesh_transformation(reference_file_directory, current_test_name): """Test that the get_mesh_transformation function works as expected.""" - curve = load_cosserat_curve_from_file() + curve = load_cosserat_curve_from_file(reference_file_directory) pos, rot = curve.get_centerline_position_and_rotation(0) rot = Rotation.from_quaternion(quaternion.as_float_array(rot)) curve.translate(-pos) curve.translate([1, 2, 3]) - mesh = create_beam_solid_input_file() + mesh = create_beam_solid_input_file(reference_file_directory) pos, rot = get_mesh_transformation( curve, mesh.nodes, @@ -185,7 +186,8 @@ def test_cosserat_mesh_transformation(): def load_result(name): """Load the position and rotation results from the reference files.""" with open( - os.path.join(testing_input, f"{get_pytest_test_name()}_{name}.json"), "r" + os.path.join(reference_file_directory, f"{current_test_name}_{name}.json"), + "r", ) as f: return np.array(json.load(f)) @@ -199,11 +201,15 @@ def load_result(name): assert np.allclose(rot_ref, rot_np, rtol=1e-14) -def test_cosserat_curve_mesh_warp(): +def test_cosserat_curve_mesh_warp( + reference_file_directory, + get_corresponding_reference_file_path, + assert_results_equal, +): """Warp a balloon along a centerline.""" # Load the curve - curve = load_cosserat_curve_from_file() + curve = load_cosserat_curve_from_file(reference_file_directory) pos, rot = curve.get_centerline_position_and_rotation(0) rot = Rotation.from_quaternion(quaternion.as_float_array(rot)) curve.translate(-pos) @@ -211,7 +217,7 @@ def test_cosserat_curve_mesh_warp(): # Warp the mesh. The reference coordinate system is rotated such that z axis is the longitudinal direction, # and x and y are the first and second cross-section basis vectors respectively. - mesh = create_beam_solid_input_file() + mesh = create_beam_solid_input_file(reference_file_directory) warp_mesh_along_curve( mesh, curve, @@ -221,24 +227,25 @@ def test_cosserat_curve_mesh_warp(): ), ) - # Compare with the reference result - compare_test_result_pytest( - mesh.get_string(check_nox=False, header=False), rtol=1e-10 - ) + assert_results_equal(get_corresponding_reference_file_path(), mesh, rtol=1e-10) -def test_cosserat_curve_mesh_warp_transform_boundary_conditions(): +def test_cosserat_curve_mesh_warp_transform_boundary_conditions( + reference_file_directory, + get_corresponding_reference_file_path, + assert_results_equal, +): """Test the transform boundary creation function.""" # Load the curve - curve = load_cosserat_curve_from_file() + curve = load_cosserat_curve_from_file(reference_file_directory) pos, rot = curve.get_centerline_position_and_rotation(0) rot = Rotation.from_quaternion(quaternion.as_float_array(rot)) curve.translate(-pos) curve.translate([1, 2, 3]) # Load the mesh - mesh = create_beam_solid_input_file() + mesh = create_beam_solid_input_file(reference_file_directory) # Apply the transform boundary conditions create_transform_boundary_conditions( @@ -251,7 +258,6 @@ def test_cosserat_curve_mesh_warp_transform_boundary_conditions(): ), ) - # Compare with the reference result - compare_test_result_pytest( - mesh.get_string(check_nox=False, header=False), rtol=1e-8, atol=1e-8 + assert_results_equal( + get_corresponding_reference_file_path(), mesh, rtol=1e-8, atol=1e-8 ) diff --git a/tests/test_dummy.py b/tests/test_dummy.py new file mode 100644 index 00000000..2fbf6535 --- /dev/null +++ b/tests/test_dummy.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# ----------------------------------------------------------------------------- +# MeshPy: A beam finite element input generator +# +# MIT License +# +# Copyright (c) 2018-2024 +# Ivo Steinbrecher +# Institute for Mathematics and Computer-Based Simulation +# Universitaet der Bundeswehr Muenchen +# https://www.unibw.de/imcs-en +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ----------------------------------------------------------------------------- +"""Dummy test to demonstrate and ensure working pytest infrastructure.""" + +from pathlib import Path + +import pytest + + +def test_dummy( + reference_file_directory: Path, + tmp_path: Path, + current_test_name: str, + get_corresponding_reference_file_path, +) -> None: + """Dummy test to demonstrate pytest fixtures. + + Args: + reference_file_directory: path to the reference file directory + tmp_path: temporary path for testing + current_test_name: name of the current test + get_corresponding_reference_file_path: path to the corresponding reference file + """ + + # approach to get reference_file_directory + print("reference_file_directory: ", reference_file_directory) + + # approach to get temporary testing path + print( + "tmp_path: ", tmp_path + ) # pytest automatically keeps the last three runs for debugging and deletes everything that's older + + # approach to get the current test name + print("current_test_name: ", current_test_name) + + # approach to get the path to the corresponding reference .dat file + print( + "corresponding reference file path: ", get_corresponding_reference_file_path() + ) + + # approach to get the path to a reference file with other base name, additional identifier and extension + print( + "corresponding reference file path: ", + get_corresponding_reference_file_path( + reference_file_base_name="test_dummy_2", + additional_identifier="id", + extension="txt", + ), + ) + + assert True + + +@pytest.mark.fourc +def test_4C() -> None: + """Test with 4C.""" + + assert True + + +@pytest.mark.arborx +def test_ArborX() -> None: + """Test with ArborX.""" + + assert True + + +@pytest.mark.cubitpy +def test_CubitPy() -> None: + """Test with CubitPy.""" + + assert True + + +@pytest.mark.performance +def test_performance() -> None: + """Performance test.""" + + assert True diff --git a/tests/test_rotations.py b/tests/test_rotations.py new file mode 100644 index 00000000..174b5997 --- /dev/null +++ b/tests/test_rotations.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +# ----------------------------------------------------------------------------- +# MeshPy: A beam finite element input generator +# +# MIT License +# +# Copyright (c) 2018-2024 +# Ivo Steinbrecher +# Institute for Mathematics and Computer-Based Simulation +# Universitaet der Bundeswehr Muenchen +# https://www.unibw.de/imcs-en +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ----------------------------------------------------------------------------- +"""This script is used to test the functionality of the Rotation class in the +meshpy module.""" + +import numpy as np + +from meshpy import Rotation, mpy +from meshpy.rotation import smallest_rotation + + +def get_rotation_matrix(axis, alpha): + """Create a rotation about one of the Cartesian axis. + + Args + ---- + axis: int + 0 - x + 1 - y + 2 - z + angle: double rotation angle + + Return + ---- + rot3D: array(3x3) + Rotation matrix for this rotation + """ + c, s = np.cos(alpha), np.sin(alpha) + rot2D = np.array(((c, -s), (s, c))) + index = [np.mod(j, 3) for j in range(axis, axis + 3) if not j == axis] + rot3D = np.eye(3) + rot3D[np.ix_(index, index)] = rot2D + return rot3D + + +def test_cartesian_rotations(): + """Create a rotation in all 3 directions. + + And compare with the rotation matrix. + """ + + # Set default values for global parameters. + mpy.set_default_values() + + theta = 1.0 + # Loop per directions. + for i in range(3): + rot3D = get_rotation_matrix(i, theta) + axis = np.zeros(3) + axis[i] = 1 + angle = theta + rotation = Rotation(axis, angle) + + # Check if the rotation is the same if it is created from its own + # quaternion and then created from its own rotation matrix. + rotation = Rotation.from_quaternion(rotation.get_quaternion()) + rotation_matrix = Rotation.from_rotation_matrix(rotation.get_rotation_matrix()) + + np.allclose(np.linalg.norm(rot3D - rotation_matrix.get_rotation_matrix()), 0.0) + + +def test_euler_angles(): + """Create a rotation with Euler angles and compare to known results.""" + + # Set default values for global parameters. + mpy.set_default_values() + + # Euler angles. + alpha = 1.1 + beta = 1.2 * np.pi * 10 + gamma = -2.5 + + # Create the rotation with rotation matrices. + Rx = get_rotation_matrix(0, alpha) + Ry = get_rotation_matrix(1, beta) + Rz = get_rotation_matrix(2, gamma) + R_euler = Rz.dot(Ry.dot(Rx)) + + # Create the rotation with the Rotation object. + rotation_x = Rotation([1, 0, 0], alpha) + rotation_y = Rotation([0, 1, 0], beta) + rotation_z = Rotation([0, 0, 1], gamma) + rotation_euler = rotation_z * rotation_y * rotation_x + np.allclose(np.linalg.norm(R_euler - rotation_euler.get_rotation_matrix()), 0.0) + assert rotation_euler == Rotation.from_rotation_matrix(R_euler) + + # Direct formula for quaternions for Euler angles. + quaternion = np.zeros(4) + cy = np.cos(gamma * 0.5) + sy = np.sin(gamma * 0.5) + cr = np.cos(alpha * 0.5) + sr = np.sin(alpha * 0.5) + cp = np.cos(beta * 0.5) + sp = np.sin(beta * 0.5) + quaternion[0] = cy * cr * cp + sy * sr * sp + quaternion[1] = cy * sr * cp - sy * cr * sp + quaternion[2] = cy * cr * sp + sy * sr * cp + quaternion[3] = sy * cr * cp - cy * sr * sp + assert Rotation.from_quaternion(quaternion) == rotation_euler + assert Rotation.from_quaternion(quaternion) == Rotation.from_quaternion( + rotation_euler.get_quaternion() + ) + assert Rotation.from_quaternion(quaternion) == Rotation.from_rotation_matrix( + R_euler + ) + + +def test_negative_angles(): + """Check if a rotation is created correctly if a negative angle or a large + angle is given.""" + + # Set default values for global parameters. + mpy.set_default_values() + + vector = 10 * np.array([-1.234243, -2.334343, -1.123123]) + phi = -12.152101868665 + rot = Rotation(vector, phi) + for i in range(2): + assert rot == Rotation(vector, phi + 2 * i * np.pi) + + rot = Rotation.from_rotation_vector(vector) + q = rot.q + assert rot == Rotation.from_quaternion(-q) + assert Rotation.from_quaternion(q) == Rotation.from_quaternion(-q) + + +def test_inverse_rotation(): + """Test the inv() function for rotations.""" + + # Set default values for global parameters. + mpy.set_default_values() + + # Define test rotation. + rot = Rotation([1, 2, 3], 2) + + # Check if inverse rotation gets identity rotation. Use two different + # constructors for identity rotation. + assert Rotation.from_rotation_vector([0, 0, 0]) == rot * rot.inv() + assert Rotation() == rot * rot.inv() + + # Check that there is no warning or error when getting the vector for + # an identity rotation. + (rot * rot.inv()).get_rotation_vector() + + +def test_rotation_vector(): + """Test if the rotation vector functions give a correct result.""" + + # Calculate rotation vector and quaternion. + axis = np.array([1.36568, -2.96784, 3.23346878]) + angle = 0.7189467 + rotation_vector = angle * axis / np.linalg.norm(axis) + q = np.zeros(4) + q[0] = np.cos(angle / 2) + q[1:] = np.sin(angle / 2) * axis / np.linalg.norm(axis) + + # Check that the rotation object from the quaternion and rotation + # vector are equal. + rotation_from_vec = Rotation.from_rotation_vector(rotation_vector) + assert Rotation.from_quaternion(q) == rotation_from_vec + assert Rotation(axis, angle) == rotation_from_vec + + # Check that the same rotation vector is returned after being converted + # to a quaternion. + np.testing.assert_array_less( + np.linalg.norm(rotation_vector - rotation_from_vec.get_rotation_vector()), + mpy.eps_quaternion, + ) + + +def test_rotation_operator_overload(): + """Test if the operator overloading gives a correct result.""" + + # Calculate rotation and vector. + axis = np.array([1.36568, -2.96784, 3.23346878]) + angle = 0.7189467 + rot = Rotation(axis, angle) + vector = [2.234234, -4.213234, 6.345234] + + # Check the result of the operator overloading. + result_vector = np.dot(rot.get_rotation_matrix(), vector) + np.testing.assert_array_less( + np.linalg.norm(result_vector - rot * vector), mpy.eps_quaternion + ) + np.testing.assert_array_less( + np.linalg.norm(result_vector - rot * np.array(vector)), mpy.eps_quaternion + ) + + +def test_rotation_matrix(): + """Test if the correct quaternions are generated from a rotation matrix.""" + + # Do one calculation for each case in + # Rotation().from_rotation_matrix(). + vectors = [ + [[1, 0, 0], [0, -1, 0]], + [[0, 0, 1], [0, 1, 0]], + [[-1, 0, 0], [0, 1, 0]], + [[0, 1, 0], [0, 0, 1]], + ] + + for t1, t2 in vectors: + rot = Rotation().from_basis(t1, t2) + t1_rot = rot * [1, 0, 0] + t2_rot = rot * [0, 1, 0] + np.testing.assert_array_less(np.linalg.norm(t1 - t1_rot), mpy.eps_quaternion) + np.testing.assert_array_less(np.linalg.norm(t2 - t2_rot), mpy.eps_quaternion) + + +def test_transformation_matrix(): + """Test that the transformation matrix is computed correctly.""" + + rotation_vector_large = [1.0, 2.0, np.pi / 5.0] + rotation_large = Rotation.from_rotation_vector(rotation_vector_large) + rotation_vector_small = ( + rotation_vector_large + / np.linalg.norm(rotation_vector_large) + / 10.0 + * mpy.eps_quaternion + ) + rotation_small = Rotation.from_rotation_vector(rotation_vector_small) + + # Test transformation matrix + transformation_matrix_large_reference = np.array( + [ + [0.5959488405656389, 0.49803685445056006, -0.9422331516950085], + [-0.13028167626739845, 0.8717652242030102, 0.6155336966099826], + [1.0577668483049911, -0.3844663033900173, 0.5403060272710478], + ] + ) + assert np.allclose( + rotation_large.get_transformation_matrix(), + transformation_matrix_large_reference, + atol=mpy.eps_quaternion, + rtol=0.0, + ) + assert np.allclose( + rotation_small.get_transformation_matrix(), + np.identity(3), + atol=mpy.eps_quaternion, + rtol=0.0, + ) + + # Test transformation matrix inverse + transformation_matrix_inverse_large_reference = np.array( + [ + [0.44154375784863675, 0.05812896596538626, 0.7037804689849043], + [0.4501610612860693, 0.8227612782872283, -0.15228520755418612], + [-0.5440964474342915, 0.4716532506554118, 0.36463746593568075], + ] + ) + assert np.allclose( + rotation_large.get_transformation_matrix_inv(), + transformation_matrix_inverse_large_reference, + atol=mpy.eps_quaternion, + rtol=0.0, + ) + assert np.allclose( + rotation_small.get_transformation_matrix_inv(), + np.identity(3), + atol=mpy.eps_quaternion, + rtol=0.0, + ) + + +def test_smallest_rotation_triad(): + """Test that the smallest rotation triad is calculated correctly.""" + + # Get the triad obtained by a smallest rotation from an arbitrary triad + # onto an arbitrary tangent vector. + rot = Rotation([1, 2, 3], 0.431 * np.pi) + tan = [2.0, 3.0, -1.0] + rot_smallest = smallest_rotation(rot, tan) + + rot_smallest_ref = [ + 0.853329730651268, + 0.19771093216880734, + 0.25192421451158936, + 0.4114279380770031, + ] + np.testing.assert_array_less( + np.linalg.norm(rot_smallest.q - rot_smallest_ref), mpy.eps_quaternion + ) + + +def test_error_accumulation_multiplication(): + """Test that error accumulation of successive multiplications of rotations + does not affect the results.""" + + rotation_1 = Rotation([1, 2, 3], 0.3) + rotation_2 = Rotation([1, -1, -2], np.pi / 6) + rotation_3 = Rotation([-1, -2, -3], 7 * np.pi / 17) + rotation = Rotation() + for _ in range(100): + rotation = rotation_1 * rotation * rotation_2 + rotation = rotation * rotation_3 + + q_ref = [ + -0.38478914485223104, + -0.0385171948379694, + -0.49122781649072017, + -0.780479962594468, + ] + assert np.allclose(q_ref, rotation.q, atol=1e-14) + + +def test_error_accumulation_smallest_rotation(): + """Test that error accumulation of successive smallest rotation mappings + does not affect the results. + + Calculate the smallest rotation onto a vector and then rotate that + vector "away" to calculate the next smallest rotation and so on... + """ + + tangent = [0.9, 0.1, -0.3] + rotation_old = Rotation([1, 2, 3], 0.3) + + for _ in range(50): + rotation_new = smallest_rotation(rotation_old, tangent) + tangent = rotation_new * rotation_old.inv() * tangent + rotation_old = rotation_new + + q_ref = [ + 0.6329069205124062, + 0.13331392718187732, + -0.5128773537467728, + 0.5644581887089211, + ] + assert np.allclose(q_ref, rotation_new.q, atol=1e-14) diff --git a/tests/testing_utility.py b/tests/test_utility.py similarity index 57% rename from tests/testing_utility.py rename to tests/test_utility.py index 1c605f31..8e290d38 100644 --- a/tests/testing_utility.py +++ b/tests/test_utility.py @@ -30,47 +30,29 @@ # ----------------------------------------------------------------------------- """Test utilities of MeshPy.""" -import unittest - from meshpy.node import Node from meshpy.utility import is_node_on_plane -class TestUtilities(unittest.TestCase): - """Test utilities from the meshpy.utility module.""" - - def test_is_node_on_plane(self): - """Test if node on plane function works properly.""" - - # node on plane with origin_distance - node = Node([1.0, 1.0, 1.0]) - self.assertTrue( - is_node_on_plane(node, normal=[0.0, 0.0, 1.0], origin_distance=1.0) - ) - - # node on plane with point_on_plane - node = Node([1.0, 1.0, 1.0]) - self.assertTrue( - is_node_on_plane( - node, normal=[0.0, 0.0, 5.0], point_on_plane=[5.0, 5.0, 1.0] - ) - ) +def test_is_node_on_plane(): + """Test if node on plane function works properly.""" - # node not on plane with origin_distance - node = Node([13.5, 14.5, 15.5]) - self.assertFalse( - is_node_on_plane(node, normal=[0.0, 0.0, 1.0], origin_distance=5.0) - ) + # node on plane with origin_distance + node = Node([1.0, 1.0, 1.0]) + assert is_node_on_plane(node, normal=[0.0, 0.0, 1.0], origin_distance=1.0) - # node not on plane with point_on_plane - node = Node([13.5, 14.5, 15.5]) - self.assertFalse( - is_node_on_plane( - node, normal=[0.0, 0.0, 5.0], point_on_plane=[5.0, 5.0, 1.0] - ) - ) + # node on plane with point_on_plane + node = Node([1.0, 1.0, 1.0]) + assert is_node_on_plane( + node, normal=[0.0, 0.0, 5.0], point_on_plane=[5.0, 5.0, 1.0] + ) + # node not on plane with origin_distance + node = Node([13.5, 14.5, 15.5]) + assert not is_node_on_plane(node, normal=[0.0, 0.0, 1.0], origin_distance=5.0) -if __name__ == "__main__": - # Execution part of script. - unittest.main() + # node not on plane with point_on_plane + node = Node([13.5, 14.5, 15.5]) + assert not is_node_on_plane( + node, normal=[0.0, 0.0, 5.0], point_on_plane=[5.0, 5.0, 1.0] + ) diff --git a/tests/testing_rotations.py b/tests/testing_rotations.py deleted file mode 100644 index 48aa19a6..00000000 --- a/tests/testing_rotations.py +++ /dev/null @@ -1,386 +0,0 @@ -# -*- coding: utf-8 -*- -# ----------------------------------------------------------------------------- -# MeshPy: A beam finite element input generator -# -# MIT License -# -# Copyright (c) 2018-2024 -# Ivo Steinbrecher -# Institute for Mathematics and Computer-Based Simulation -# Universitaet der Bundeswehr Muenchen -# https://www.unibw.de/imcs-en -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ----------------------------------------------------------------------------- -"""This script is used to test the functionality of the Rotation class in the -meshpy module.""" - -import unittest - -import numpy as np - -from meshpy import Rotation, mpy -from meshpy.rotation import smallest_rotation - - -class TestRotation(unittest.TestCase): - """This class tests the implementation of the Rotation class.""" - - def rotation_matrix(self, axis, alpha): - """Create a rotation about one of the Cartesian axis. - - Args - ---- - axis: int - 0 - x - 1 - y - 2 - z - angle: double rotation angle - - Return - ---- - rot3D: array(3x3) - Rotation matrix for this rotation - """ - c, s = np.cos(alpha), np.sin(alpha) - rot2D = np.array(((c, -s), (s, c))) - index = [np.mod(j, 3) for j in range(axis, axis + 3) if not j == axis] - rot3D = np.eye(3) - rot3D[np.ix_(index, index)] = rot2D - return rot3D - - def test_cartesian_rotations(self): - """Create a rotation in all 3 directions. - - And compare with the rotation matrix. - """ - - # Set default values for global parameters. - mpy.set_default_values() - - theta = 1.0 - # Loop per directions. - for i in range(3): - rot3D = self.rotation_matrix(i, theta) - axis = np.zeros(3) - axis[i] = 1 - angle = theta - rotation = Rotation(axis, angle) - - # Check if the rotation is the same if it is created from its own - # quaternion and then created from its own rotation matrix. - rotation = Rotation.from_quaternion(rotation.get_quaternion()) - rotation_matrix = Rotation.from_rotation_matrix( - rotation.get_rotation_matrix() - ) - - self.assertAlmostEqual( - np.linalg.norm(rot3D - rotation_matrix.get_rotation_matrix()), 0.0 - ) - - def test_euler_angles(self): - """Create a rotation with Euler angles and compare to known results.""" - - # Set default values for global parameters. - mpy.set_default_values() - - # Euler angles. - alpha = 1.1 - beta = 1.2 * np.pi * 10 - gamma = -2.5 - - # Create the rotation with rotation matrices. - Rx = self.rotation_matrix(0, alpha) - Ry = self.rotation_matrix(1, beta) - Rz = self.rotation_matrix(2, gamma) - R_euler = Rz.dot(Ry.dot(Rx)) - - # Create the rotation with the Rotation object. - rotation_x = Rotation([1, 0, 0], alpha) - rotation_y = Rotation([0, 1, 0], beta) - rotation_z = Rotation([0, 0, 1], gamma) - rotation_euler = rotation_z * rotation_y * rotation_x - self.assertAlmostEqual( - np.linalg.norm(R_euler - rotation_euler.get_rotation_matrix()), 0.0 - ) - self.assertTrue(rotation_euler == Rotation.from_rotation_matrix(R_euler)) - - # Direct formula for quaternions for Euler angles. - quaternion = np.zeros(4) - cy = np.cos(gamma * 0.5) - sy = np.sin(gamma * 0.5) - cr = np.cos(alpha * 0.5) - sr = np.sin(alpha * 0.5) - cp = np.cos(beta * 0.5) - sp = np.sin(beta * 0.5) - quaternion[0] = cy * cr * cp + sy * sr * sp - quaternion[1] = cy * sr * cp - sy * cr * sp - quaternion[2] = cy * cr * sp + sy * sr * cp - quaternion[3] = sy * cr * cp - cy * sr * sp - self.assertTrue(Rotation.from_quaternion(quaternion) == rotation_euler) - self.assertTrue( - Rotation.from_quaternion(quaternion) - == Rotation.from_quaternion(rotation_euler.get_quaternion()) - ) - self.assertTrue( - Rotation.from_quaternion(quaternion) - == Rotation.from_rotation_matrix(R_euler) - ) - - def test_negative_angles(self): - """Check if a rotation is created correctly if a negative angle or a - large angle is given.""" - - # Set default values for global parameters. - mpy.set_default_values() - - vector = 10 * np.array([-1.234243, -2.334343, -1.123123]) - phi = -12.152101868665 - rot = Rotation(vector, phi) - for i in range(2): - self.assertTrue(rot == Rotation(vector, phi + 2 * i * np.pi)) - - rot = Rotation.from_rotation_vector(vector) - q = rot.q - self.assertTrue(rot == Rotation.from_quaternion(-q)) - self.assertTrue(Rotation.from_quaternion(q) == Rotation.from_quaternion(-q)) - - def test_inverse_rotation(self): - """Test the inv() function for rotations.""" - - # Set default values for global parameters. - mpy.set_default_values() - - # Define test rotation. - rot = Rotation([1, 2, 3], 2) - - # Check if inverse rotation gets identity rotation. Use two different - # constructors for identity rotation. - self.assertTrue(Rotation.from_rotation_vector([0, 0, 0]) == rot * rot.inv()) - self.assertTrue(Rotation() == rot * rot.inv()) - - # Check that there is no warning or error when getting the vector for - # an identity rotation. - (rot * rot.inv()).get_rotation_vector() - - def test_rotation_vector(self): - """Test if the rotation vector functions give a correct result.""" - - # Calculate rotation vector and quaternion. - axis = np.array([1.36568, -2.96784, 3.23346878]) - angle = 0.7189467 - rotation_vector = angle * axis / np.linalg.norm(axis) - q = np.zeros(4) - q[0] = np.cos(angle / 2) - q[1:] = np.sin(angle / 2) * axis / np.linalg.norm(axis) - - # Check that the rotation object from the quaternion and rotation - # vector are equal. - rotation_from_vec = Rotation.from_rotation_vector(rotation_vector) - self.assertTrue(Rotation.from_quaternion(q) == rotation_from_vec) - self.assertTrue(Rotation(axis, angle) == rotation_from_vec) - - # Check that the same rotation vector is returned after being converted - # to a quaternion. - self.assertLess( - np.linalg.norm(rotation_vector - rotation_from_vec.get_rotation_vector()), - mpy.eps_quaternion, - "test_rotation_vector", - ) - - def test_rotation_operator_overload(self): - """Test if the operator overloading gives a correct result.""" - - # Calculate rotation and vector. - axis = np.array([1.36568, -2.96784, 3.23346878]) - angle = 0.7189467 - rot = Rotation(axis, angle) - vector = [2.234234, -4.213234, 6.345234] - - # Check the result of the operator overloading. - result_vector = np.dot(rot.get_rotation_matrix(), vector) - self.assertLess( - np.linalg.norm(result_vector - rot * vector), - mpy.eps_quaternion, - "test_rotation_vector", - ) - self.assertLess( - np.linalg.norm(result_vector - rot * np.array(vector)), - mpy.eps_quaternion, - "test_rotation_vector", - ) - - def test_rotation_matrix(self): - """Test if the correct quaternions are generated from a rotation - matrix.""" - - # Do one calculation for each case in - # Rotation().from_rotation_matrix(). - vectors = [ - [[1, 0, 0], [0, -1, 0]], - [[0, 0, 1], [0, 1, 0]], - [[-1, 0, 0], [0, 1, 0]], - [[0, 1, 0], [0, 0, 1]], - ] - - for t1, t2 in vectors: - rot = Rotation().from_basis(t1, t2) - t1_rot = rot * [1, 0, 0] - t2_rot = rot * [0, 1, 0] - self.assertLess( - np.linalg.norm(t1 - t1_rot), - mpy.eps_quaternion, - "test_rotation_matrix: compare t1", - ) - self.assertLess( - np.linalg.norm(t2 - t2_rot), - mpy.eps_quaternion, - "test_rotation_matrix: compare t2", - ) - - def test_transformation_matrix(self): - """Test that the transformation matrix is computed correctly.""" - - rotation_vector_large = [1.0, 2.0, np.pi / 5.0] - rotation_large = Rotation.from_rotation_vector(rotation_vector_large) - rotation_vector_small = ( - rotation_vector_large - / np.linalg.norm(rotation_vector_large) - / 10.0 - * mpy.eps_quaternion - ) - rotation_small = Rotation.from_rotation_vector(rotation_vector_small) - - # Test transformation matrix - transformation_matrix_large_reference = np.array( - [ - [0.5959488405656389, 0.49803685445056006, -0.9422331516950085], - [-0.13028167626739845, 0.8717652242030102, 0.6155336966099826], - [1.0577668483049911, -0.3844663033900173, 0.5403060272710478], - ] - ) - self.assertTrue( - np.allclose( - rotation_large.get_transformation_matrix(), - transformation_matrix_large_reference, - atol=mpy.eps_quaternion, - rtol=0.0, - ) - ) - self.assertTrue( - np.allclose( - rotation_small.get_transformation_matrix(), - np.identity(3), - atol=mpy.eps_quaternion, - rtol=0.0, - ) - ) - - # Test transformation matrix inverse - transformation_matrix_inverse_large_reference = np.array( - [ - [0.44154375784863675, 0.05812896596538626, 0.7037804689849043], - [0.4501610612860693, 0.8227612782872283, -0.15228520755418612], - [-0.5440964474342915, 0.4716532506554118, 0.36463746593568075], - ] - ) - self.assertTrue( - np.allclose( - rotation_large.get_transformation_matrix_inv(), - transformation_matrix_inverse_large_reference, - atol=mpy.eps_quaternion, - rtol=0.0, - ) - ) - self.assertTrue( - np.allclose( - rotation_small.get_transformation_matrix_inv(), - np.identity(3), - atol=mpy.eps_quaternion, - rtol=0.0, - ) - ) - - def test_smallest_rotation_triad(self): - """Test that the smallest rotation triad is calculated correctly.""" - - # Get the triad obtained by a smallest rotation from an arbitrary triad - # onto an arbitrary tangent vector. - rot = Rotation([1, 2, 3], 0.431 * np.pi) - tan = [2.0, 3.0, -1.0] - rot_smallest = smallest_rotation(rot, tan) - - rot_smallest_ref = [ - 0.853329730651268, - 0.19771093216880734, - 0.25192421451158936, - 0.4114279380770031, - ] - self.assertLess( - np.linalg.norm(rot_smallest.q - rot_smallest_ref), mpy.eps_quaternion - ) - - def test_error_accumulation_multiplication(self): - """Test that error accumulation of successive multiplications of - rotations does not affect the results.""" - - rotation_1 = Rotation([1, 2, 3], 0.3) - rotation_2 = Rotation([1, -1, -2], np.pi / 6) - rotation_3 = Rotation([-1, -2, -3], 7 * np.pi / 17) - rotation = Rotation() - for _ in range(100): - rotation = rotation_1 * rotation * rotation_2 - rotation = rotation * rotation_3 - - q_ref = [ - -0.38478914485223104, - -0.0385171948379694, - -0.49122781649072017, - -0.780479962594468, - ] - assert np.allclose(q_ref, rotation.q, atol=1e-14) - - def test_error_accumulation_smallest_rotation(self): - """Test that error accumulation of successive smallest rotation - mappings does not affect the results. - - Calculate the smallest rotation onto a vector and then rotate that - vector "away" to calculate the next smallest rotation and so on... - """ - - tangent = [0.9, 0.1, -0.3] - rotation_old = Rotation([1, 2, 3], 0.3) - - for _ in range(50): - rotation_new = smallest_rotation(rotation_old, tangent) - tangent = rotation_new * rotation_old.inv() * tangent - rotation_old = rotation_new - - q_ref = [ - 0.6329069205124062, - 0.13331392718187732, - -0.5128773537467728, - 0.5644581887089211, - ] - assert np.allclose(q_ref, rotation_new.q, atol=1e-14) - - -if __name__ == "__main__": - # Execution part of script. - unittest.main() diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index e461941a..00000000 --- a/tests/utils.py +++ /dev/null @@ -1,422 +0,0 @@ -# -*- coding: utf-8 -*- -# ----------------------------------------------------------------------------- -# MeshPy: A beam finite element input generator -# -# MIT License -# -# Copyright (c) 2018-2024 -# Ivo Steinbrecher -# Institute for Mathematics and Computer-Based Simulation -# Universitaet der Bundeswehr Muenchen -# https://www.unibw.de/imcs-en -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ----------------------------------------------------------------------------- -"""Define utility functions for the testing process.""" - -import os -import shutil -import subprocess - -import numpy as np -import vtk -from vtk_utils.compare_grids import compare_grids - -from meshpy.utility import get_env_variable - - -def get_pytest_test_name(): - """Return the name of the current pytest test.""" - return os.environ.get("PYTEST_CURRENT_TEST").split(":")[-1].split(" ")[0] - - -def skip_fail_four_c(self): - """Check if a 4C executable can be found. - - If TESTING_GITHUB_4C==1 then we raise an error if we cant find the - 4C executable, otherwise the test is skipped - """ - - message = "Can not find 4C executable" - four_c_path = get_env_variable("MESHPY_FOUR_C_EXE", default="") - if not os.path.isfile(four_c_path): - if get_env_variable("TESTING_GITHUB_4C", default="0") == "1": - raise ImportError(message) - else: - self.skipTest(message) - - -def skip_fail_arborx(self): - """Check if ArborX geometric search can be loaded. - - If TESTING_GITHUB_ARBORX==1 then we raise an error if we cant load - ArborX, otherwise the test is skipped - """ - - from meshpy.geometric_search.geometric_search_arborx import arborx_available - - message = "Can not import ArborX geometric search" - if not arborx_available: - if get_env_variable("TESTING_GITHUB_ARBORX", default="0") == "1": - raise ImportError(message) - else: - self.skipTest(message) - - -def skip_fail_cubitpy(self): - """Check if CubitPy can be loaded. - - If TESTING_GITHUB_CUBITPY==1 then we raise an error if we cant load - CubitPy, otherwise the test is skipped - """ - message = "Can not import and initialize CubitPy" - try: - from cubitpy import CubitPy - - cubit = CubitPy() - except Exception: - if get_env_variable("TESTING_GITHUB_CUBITPY", default="0") == "1": - raise ImportError(message) - else: - self.skipTest(message) - - -# Define the testing paths -testing_path = os.path.abspath(os.path.dirname(__file__)) -testing_input = os.path.join(testing_path, "reference-files") -testing_temp = os.path.join(testing_path, "testing-tmp") - -# Check and clean the temporary directory. -os.makedirs(testing_temp, exist_ok=True) - - -def empty_testing_directory(): - """Delete all files in the testing directory, if it exists.""" - if os.path.isdir(testing_temp): - for the_file in os.listdir(testing_temp): - file_path = os.path.join(testing_temp, the_file) - try: - if os.path.isfile(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print(e) - - -def compare_string_tolerance( - reference, compare, *, rtol=None, atol=None, split_string=" " -): - """Compare two strings, all floating point values will be compared with a - tolerance.""" - - rtol = 0.0 if rtol is None else rtol - atol = 0.0 if atol is None else atol - - lines_reference = reference.strip().split("\n") - lines_compare = compare.strip().split("\n") - n_reference = len(lines_reference) - n_compare = len(lines_compare) - if n_reference == n_compare: - # Loop over each line in the file - for i in range(n_reference): - line_reference = lines_reference[i].strip().split(split_string) - line_compare = lines_compare[i].strip().split(split_string) - n_items_reference = len(line_reference) - n_items_compare = len(line_compare) - if n_items_reference == n_items_compare: - # Loop over each entry in the line - for j in range(n_items_reference): - try: - reference_number = float(line_reference[j].strip()) - compare_number = float(line_compare[j].strip()) - if np.isclose( - reference_number, compare_number, rtol=rtol, atol=atol - ): - pass - else: - return False - except ValueError: - if line_reference[j].strip() != line_compare[j].strip(): - return False - else: - return False - else: - return False - - return True - - -def compare_test_result( - self, - result_string, - *, - extension="dat", - reference_file_base_name=None, - additional_identifier=None, - **kwargs, -): - """Compare a created string in a test with the reference results. The - reference results are stored in a file made up of the test name. The - filename will always end with "_reference". - - Args - ---- - result_string: str - String to compare with a reference file - reference_file_base_name: str - Base name of the reference file to compare with. Defaults to the name of the - current test - additional_identifier: str - Will be added after the base reference file name - extension: str - File extension of the reference file - """ - - if reference_file_base_name is None: - reference_file_base_name = self._testMethodName - - if additional_identifier is not None: - reference_file_base_name += f"_{additional_identifier}" - - if extension is not None: - reference_file_base_name += "." + extension - - reference_file_path = os.path.join(testing_input, reference_file_base_name) - - # Compare the results - compare_strings(self, reference_file_path, result_string, **kwargs) - - -def compare_strings(self, reference, compare, *, rtol=None, atol=None, **kwargs): - """Compare two stings. - - If they are not identical open a comparison and show the - differences. - """ - - def check_is_file_get_string(item): - """Check if the input data is a file that exists or a string.""" - is_file = os.path.isfile(item) - if is_file: - with open(item, "r") as myfile: - string = myfile.read() - else: - string = item - return is_file, string - - reference_is_file, reference_string = check_is_file_get_string(reference) - compare_is_file, compare_string = check_is_file_get_string(compare) - - if rtol is None and atol is None: - # Check if the strings are equal, if not compare the differences and - # fail the test. - is_equal = reference_string.strip() == compare_string.strip() - else: - is_equal = compare_string_tolerance( - reference_string, compare_string, rtol=rtol, atol=atol, **kwargs - ) - - message = f"Test: {self._testMethodName}" - if not is_equal: - # Check if temporary directory exists, and creates it if necessary. - os.makedirs(testing_temp, exist_ok=True) - - def get_compare_paths(item, is_file, string): - """Get the paths of the files to compare. - - If a string was given create a file with the string in it. - """ - if is_file: - file = item - else: - file = os.path.join( - testing_temp, - "{}_failed_test_compare.dat".format(self._testMethodName), - ) - with open(file, "w") as f: - f.write(string) - return file - - reference_file = get_compare_paths( - reference, reference_is_file, reference_string - ) - compare_file = get_compare_paths(compare, compare_is_file, compare_string) - - message += f"\nCompare strings failed. Files:\n ref: {reference_file}\n res: {compare_file}" - - if shutil.which("code") is not None: - child = subprocess.Popen( - ["code", "--diff", reference_file, compare_file], stderr=subprocess.PIPE - ) - child.communicate() - else: - result = subprocess.run( - ["diff", reference_file, compare_file], stdout=subprocess.PIPE - ) - self._testMethodName += "\n\nDiff:\n" + result.stdout.decode("utf-8") - - # Check the results. - self.assertTrue(is_equal, message) - - -def compare_vtk(self, path_1, path_2, *, rtol=1e-14, atol=1e-14): - """Compare two vtk files and raise an error if they are not equal.""" - - def get_vtk(path): - """Return a vtk object for the file at path.""" - reader = vtk.vtkXMLGenericDataObjectReader() - reader.SetFileName(path) - reader.Update() - return reader.GetOutput() - - compare = compare_grids( - get_vtk(path_1), get_vtk(path_2), output=True, rtol=rtol, atol=atol - ) - self.assertTrue(compare[0], msg="\n".join(compare[1])) - - -def compare_test_result_pytest( - result_string, - *, - extension="dat", - reference_file_base_name=None, - additional_identifier=None, - **kwargs, -): - """Compare a created string in a test with the reference results. The - reference results are stored in a file made up of the test name. The - filename will always end with "_reference". - - Args - ---- - result_string: str - String to compare with a reference file - reference_file_base_name: str - Base name of the reference file to compare with. Defaults to the name of the - current test - additional_identifier: str - Will be added after the base reference file name - extension: str - File extension of the reference file - """ - - if reference_file_base_name is None: - reference_file_base_name = get_pytest_test_name() - - if additional_identifier is not None: - reference_file_base_name += f"_{additional_identifier}" - - if extension is not None: - reference_file_base_name += "." + extension - - reference_file_path = os.path.join(testing_input, reference_file_base_name) - - # Compare the results - compare_strings_pytest(reference_file_path, result_string, **kwargs) - - -def compare_strings_pytest(reference, compare, *, rtol=None, atol=None, **kwargs): - """Compare two stings. - - If they are not identical open a comparison and show the - differences. - """ - - def check_is_file_get_string(item): - """Check if the input data is a file that exists or a string.""" - is_file = os.path.isfile(item) - if is_file: - with open(item, "r") as myfile: - string = myfile.read() - else: - string = item - return is_file, string - - reference_is_file, reference_string = check_is_file_get_string(reference) - compare_is_file, compare_string = check_is_file_get_string(compare) - - if rtol is None and atol is None: - # Check if the strings are equal, if not compare the differences and - # fail the test. - is_equal = reference_string.strip() == compare_string.strip() - else: - is_equal = compare_string_tolerance( - reference_string, compare_string, rtol=rtol, atol=atol, **kwargs - ) - - test_name = get_pytest_test_name() - message = f"Test: {test_name}" - if not is_equal: - # Check if temporary directory exists, and creates it if necessary. - os.makedirs(testing_temp, exist_ok=True) - - def get_compare_paths(item, is_file, string, name): - """Get the paths of the files to compare. - - If a string was given create a file with the string in it. - """ - if is_file: - file = item - else: - file = os.path.join( - testing_temp, - f"{test_name}_failed_test_{name}.dat", - ) - with open(file, "w") as f: - f.write(string) - return file - - reference_file = get_compare_paths( - reference, reference_is_file, reference_string, "reference" - ) - compare_file = get_compare_paths( - compare, compare_is_file, compare_string, "compare" - ) - - message += f"\nCompare strings failed. Files:\n ref: {reference_file}\n res: {compare_file}" - - if shutil.which("code") is not None: - child = subprocess.Popen( - ["code", "--diff", reference_file, compare_file], stderr=subprocess.PIPE - ) - child.communicate() - else: - result = subprocess.run( - ["diff", reference_file, compare_file], stdout=subprocess.PIPE - ) - - # Check the results. - assert is_equal, message - - -def compare_vtk_pytest(path_1, path_2, *, rtol=1e-14, atol=1e-14): - """Compare two vtk files and raise an error if they are not equal.""" - - def get_vtk(path): - """Return a vtk object for the file at path.""" - reader = vtk.vtkXMLGenericDataObjectReader() - reader.SetFileName(path) - reader.Update() - return reader.GetOutput() - - compare = compare_grids( - get_vtk(path_1), get_vtk(path_2), output=True, rtol=rtol, atol=atol - ) - assert compare[0], "\n".join(compare[1])