7
7
import sys
8
8
import tempfile
9
9
import time
10
+ from enum import Enum
10
11
from pathlib import Path
11
12
from types import TracebackType
12
13
from typing import (
@@ -444,14 +445,18 @@ def assert_bash_exec(
444
445
class bash_env_saved :
445
446
counter : int = 0
446
447
448
+ class saved_state (Enum ):
449
+ ChangesDetected = 1
450
+ ChangesIgnored = 2
451
+
447
452
def __init__ (self , bash : pexpect .spawn , sendintr : bool = False ):
448
453
bash_env_saved .counter += 1
449
454
self .prefix : str = "_comp__test_%d" % bash_env_saved .counter
450
455
451
456
self .bash = bash
452
457
self .cwd_changed : bool = False
453
- self .saved_shopt : Dict [str , int ] = {}
454
- self .saved_variables : Dict [str , int ] = {}
458
+ self .saved_shopt : Dict [str , bash_env_saved . saved_state ] = {}
459
+ self .saved_variables : Dict [str , bash_env_saved . saved_state ] = {}
455
460
self .sendintr = sendintr
456
461
457
462
self .noexcept : bool = False
@@ -516,14 +521,19 @@ def _save_cwd(self):
516
521
self ._copy_variable ("PWD" , "%s_OLDPWD" % self .prefix )
517
522
518
523
def _check_shopt (self , name : str ):
524
+ if (
525
+ self .saved_shopt [name ]
526
+ != bash_env_saved .saved_state .ChangesDetected
527
+ ):
528
+ return
519
529
self ._safe_assert (
520
530
'[[ $(shopt -p %s) == "${%s_NEWSHOPT_%s}" ]]'
521
531
% (name , self .prefix , name ),
522
532
)
523
533
524
534
def _unprotect_shopt (self , name : str ):
525
535
if name not in self .saved_shopt :
526
- self .saved_shopt [name ] = 1
536
+ self .saved_shopt [name ] = bash_env_saved . saved_state . ChangesDetected
527
537
self ._safe_exec (
528
538
"%s_OLDSHOPT_%s=$(shopt -p %s || true)"
529
539
% (self .prefix , name , name ),
@@ -538,6 +548,11 @@ def _protect_shopt(self, name: str):
538
548
)
539
549
540
550
def _check_variable (self , varname : str ):
551
+ if (
552
+ self .saved_variables [varname ]
553
+ != bash_env_saved .saved_state .ChangesDetected
554
+ ):
555
+ return
541
556
try :
542
557
self ._safe_assert (
543
558
'[[ ${%s-%s} == "${%s_NEWVAR_%s-%s}" ]]'
@@ -556,7 +571,9 @@ def _check_variable(self, varname: str):
556
571
557
572
def _unprotect_variable (self , varname : str ):
558
573
if varname not in self .saved_variables :
559
- self .saved_variables [varname ] = 1
574
+ self .saved_variables [
575
+ varname
576
+ ] = bash_env_saved .saved_state .ChangesDetected
560
577
self ._copy_variable (
561
578
varname , "%s_OLDVAR_%s" % (self .prefix , varname )
562
579
)
@@ -581,13 +598,6 @@ def _restore_env(self):
581
598
self ._unset_variable ("%s_OLDPWD" % self .prefix )
582
599
self .cwd_changed = False
583
600
584
- for name in self .saved_shopt :
585
- self ._check_shopt (name )
586
- self ._safe_exec ('eval "$%s_OLDSHOPT_%s"' % (self .prefix , name ))
587
- self ._unset_variable ("%s_OLDSHOPT_%s" % (self .prefix , name ))
588
- self ._unset_variable ("%s_NEWSHOPT_%s" % (self .prefix , name ))
589
- self .saved_shopt = {}
590
-
591
601
for varname in self .saved_variables :
592
602
self ._check_variable (varname )
593
603
self ._copy_variable (
@@ -597,6 +607,13 @@ def _restore_env(self):
597
607
self ._unset_variable ("%s_NEWVAR_%s" % (self .prefix , varname ))
598
608
self .saved_variables = {}
599
609
610
+ for name in self .saved_shopt :
611
+ self ._check_shopt (name )
612
+ self ._safe_exec ('eval "$%s_OLDSHOPT_%s"' % (self .prefix , name ))
613
+ self ._unset_variable ("%s_OLDSHOPT_%s" % (self .prefix , name ))
614
+ self ._unset_variable ("%s_NEWSHOPT_%s" % (self .prefix , name ))
615
+ self .saved_shopt = {}
616
+
600
617
self .noexcept = False
601
618
if self .captured_error :
602
619
raise self .captured_error
@@ -616,13 +633,23 @@ def shopt(self, name: str, value: bool):
616
633
self ._safe_exec ("shopt -u %s" % name )
617
634
self ._protect_shopt (name )
618
635
636
+ def save_shopt (self , name : str ):
637
+ self ._unprotect_shopt (name )
638
+ self .saved_shopt [name ] = bash_env_saved .saved_state .ChangesIgnored
639
+
619
640
def write_variable (self , varname : str , new_value : str , quote : bool = True ):
620
641
if quote :
621
642
new_value = shlex .quote (new_value )
622
643
self ._unprotect_variable (varname )
623
644
self ._safe_exec ("%s=%s" % (varname , new_value ))
624
645
self ._protect_variable (varname )
625
646
647
+ def save_variable (self , varname : str ):
648
+ self ._unprotect_variable (varname )
649
+ self .saved_variables [
650
+ varname
651
+ ] = bash_env_saved .saved_state .ChangesIgnored
652
+
626
653
# TODO: We may restore the "export" attribute as well though it is
627
654
# not currently tested in "diff_env"
628
655
def write_env (self , envname : str , new_value : str , quote : bool = True ):
0 commit comments