Skip to content

Commit 4411787

Browse files
committed
test(bash_env_saved): add "save_{shopt,variable}" to allow any changes
1 parent eb94fb1 commit 4411787

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

test/t/conftest.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import tempfile
99
import time
10+
from enum import Enum
1011
from pathlib import Path
1112
from types import TracebackType
1213
from typing import (
@@ -444,14 +445,18 @@ def assert_bash_exec(
444445
class bash_env_saved:
445446
counter: int = 0
446447

448+
class saved_state(Enum):
449+
ChangesDetected = 1
450+
ChangesIgnored = 2
451+
447452
def __init__(self, bash: pexpect.spawn, sendintr: bool = False):
448453
bash_env_saved.counter += 1
449454
self.prefix: str = "_comp__test_%d" % bash_env_saved.counter
450455

451456
self.bash = bash
452457
self.cwd_changed: bool = False
453-
self.saved_shopt: Dict[str, int] = {}
454-
self.saved_variables: Dict[str, int] = {}
458+
self.saved_shopt: Dict[str, bash_env_saved.saved_state] = {}
459+
self.saved_variables: Dict[str, bash_env_saved.saved_state] = {}
455460
self.sendintr = sendintr
456461

457462
self.noexcept: bool = False
@@ -516,14 +521,19 @@ def _save_cwd(self):
516521
self._copy_variable("PWD", "%s_OLDPWD" % self.prefix)
517522

518523
def _check_shopt(self, name: str):
524+
if (
525+
self.saved_shopt[name]
526+
!= bash_env_saved.saved_state.ChangesDetected
527+
):
528+
return
519529
self._safe_assert(
520530
'[[ $(shopt -p %s) == "${%s_NEWSHOPT_%s}" ]]'
521531
% (name, self.prefix, name),
522532
)
523533

524534
def _unprotect_shopt(self, name: str):
525535
if name not in self.saved_shopt:
526-
self.saved_shopt[name] = 1
536+
self.saved_shopt[name] = bash_env_saved.saved_state.ChangesDetected
527537
self._safe_exec(
528538
"%s_OLDSHOPT_%s=$(shopt -p %s || true)"
529539
% (self.prefix, name, name),
@@ -538,6 +548,11 @@ def _protect_shopt(self, name: str):
538548
)
539549

540550
def _check_variable(self, varname: str):
551+
if (
552+
self.saved_variables[varname]
553+
!= bash_env_saved.saved_state.ChangesDetected
554+
):
555+
return
541556
try:
542557
self._safe_assert(
543558
'[[ ${%s-%s} == "${%s_NEWVAR_%s-%s}" ]]'
@@ -556,7 +571,9 @@ def _check_variable(self, varname: str):
556571

557572
def _unprotect_variable(self, varname: str):
558573
if varname not in self.saved_variables:
559-
self.saved_variables[varname] = 1
574+
self.saved_variables[
575+
varname
576+
] = bash_env_saved.saved_state.ChangesDetected
560577
self._copy_variable(
561578
varname, "%s_OLDVAR_%s" % (self.prefix, varname)
562579
)
@@ -581,13 +598,6 @@ def _restore_env(self):
581598
self._unset_variable("%s_OLDPWD" % self.prefix)
582599
self.cwd_changed = False
583600

584-
for name in self.saved_shopt:
585-
self._check_shopt(name)
586-
self._safe_exec('eval "$%s_OLDSHOPT_%s"' % (self.prefix, name))
587-
self._unset_variable("%s_OLDSHOPT_%s" % (self.prefix, name))
588-
self._unset_variable("%s_NEWSHOPT_%s" % (self.prefix, name))
589-
self.saved_shopt = {}
590-
591601
for varname in self.saved_variables:
592602
self._check_variable(varname)
593603
self._copy_variable(
@@ -597,6 +607,13 @@ def _restore_env(self):
597607
self._unset_variable("%s_NEWVAR_%s" % (self.prefix, varname))
598608
self.saved_variables = {}
599609

610+
for name in self.saved_shopt:
611+
self._check_shopt(name)
612+
self._safe_exec('eval "$%s_OLDSHOPT_%s"' % (self.prefix, name))
613+
self._unset_variable("%s_OLDSHOPT_%s" % (self.prefix, name))
614+
self._unset_variable("%s_NEWSHOPT_%s" % (self.prefix, name))
615+
self.saved_shopt = {}
616+
600617
self.noexcept = False
601618
if self.captured_error:
602619
raise self.captured_error
@@ -616,13 +633,23 @@ def shopt(self, name: str, value: bool):
616633
self._safe_exec("shopt -u %s" % name)
617634
self._protect_shopt(name)
618635

636+
def save_shopt(self, name: str):
637+
self._unprotect_shopt(name)
638+
self.saved_shopt[name] = bash_env_saved.saved_state.ChangesIgnored
639+
619640
def write_variable(self, varname: str, new_value: str, quote: bool = True):
620641
if quote:
621642
new_value = shlex.quote(new_value)
622643
self._unprotect_variable(varname)
623644
self._safe_exec("%s=%s" % (varname, new_value))
624645
self._protect_variable(varname)
625646

647+
def save_variable(self, varname: str):
648+
self._unprotect_variable(varname)
649+
self.saved_variables[
650+
varname
651+
] = bash_env_saved.saved_state.ChangesIgnored
652+
626653
# TODO: We may restore the "export" attribute as well though it is
627654
# not currently tested in "diff_env"
628655
def write_env(self, envname: str, new_value: str, quote: bool = True):

0 commit comments

Comments
 (0)