
Evaluation code of references is slightly off #4559


Closed
NicolasHug opened this issue Oct 7, 2021 · 16 comments · Fixed by #4609

Comments

@NicolasHug
Member

NicolasHug commented Oct 7, 2021

There is a subtle known bug in the evaluation code of the classification references (and other references as well, but not all):

# FIXME need to take into account that the datasets
# could have been padded in distributed setup

It deserves some attention, because it's easy to miss and yet can impact our reported results, and those of research papers.

As the comment above describes, when computing the accuracy of the model on a validation set in a distributed setting, some images will be counted more than once if len(dataset) isn't divisible by batch_size * world_size [1].

On top of that, since the test_sampler uses shuffle=True by default, the duplicated images aren't even the same across executions, which means that evaluating the same model on the same dataset can lead to different results every time.

Should we try to fix this, or should we just leave it and wait for the new lightning recipes to handle it? And as a follow-up question, is there a builtin way in lightning to mitigate this at all? (I'm not familiar with lightning, so this one may not make sense.)

cc @datumbox

Footnotes

  1. For example if we have 10 images and 2 workers with a batch_size of 3, we will have something like:

    worker1: img1, img2, img3
    worker2: img4, img5, img6
    worker1: img7, img8, img9
    worker2: img10, **img1, img2** 
                      ^^^^^^^^^
     "padding": duplicated images which will affect the validation accuracy
    
@datumbox
Contributor

datumbox commented Oct 7, 2021

Thanks for raising.

Here are some thoughts:

  1. Reporting slightly biased val stats during the training process is less concerning. I agree that this is concerning when we report final results.
  2. We could change the batch_size to 1 similar to what we do on detection. If that slows down the epochs too much we can try doing it only when --test-only is passed. This way we will be able to report correct stats for our models.
  3. Setting shuffle=False for testing is IMO something worth doing anyway to get more determinism.
  4. We could also try setting drop_last=True on the DataLoader during training. Not sure it improves things by much though.

I think lightning is likely to face similar issues, but for now I think it's worth patching batch_size=1 when --test-only is passed and setting shuffle=False on the validation sampler. This can be a good bootcamp task BTW. Let me know what you think.
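
In sketch form, the change could look roughly like this (a sketch only; names such as test_only and dataset_test follow the usual reference-script conventions but are assumptions here, not the exact code in train.py):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_test_loader(dataset_test, batch_size, test_only, num_workers=4):
    # Deterministic order for evaluation; with --test-only also force batch_size=1,
    # as proposed above, so that the final reported stats are easier to reason about.
    if test_only:
        batch_size = 1
    test_sampler = DistributedSampler(dataset_test, shuffle=False)
    return DataLoader(dataset_test, batch_size=batch_size, sampler=test_sampler, num_workers=num_workers)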

@NicolasHug
Member Author

NicolasHug commented Oct 7, 2021

Thanks for sharing your thoughts @datumbox

I'm not sure that setting the batch_size to 1 will avoid the issue: we'll still get duplicated images if len(dataset) isn't divisible by world_size. For example with 10 samples, 3 workers and batch_size=1 I think we get something like this:

I'm rank 0 and I got tensor([0]))
I'm rank 1 and I got tensor([1]))
I'm rank 2 and I got tensor([2]))
I'm rank 0 and I got tensor([3]))
I'm rank 1 and I got tensor([4]))
I'm rank 2 and I got tensor([5]))
I'm rank 0 and I got tensor([6]))
I'm rank 1 and I got tensor([7]))
I'm rank 2 and I got tensor([8]))
I'm rank 0 and I got tensor([9]))
I'm rank 1 and I got tensor([0]))  <-- dup
I'm rank 2 and I got tensor([1]))  <-- dup

source:

torchrun --nproc_per_node=3 --nnodes 1 lol.py
import os
import torch

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __getitem__(self, index):
        return index

    def __len__(self):
        return 10


local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method="env://")

dataset = MyDataset()
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False, drop_last=False)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)
for e in dataloader:
    for r in range(world_size):
        torch.distributed.barrier()
        if r == rank:
            print(f"I'm rank {rank} and I got {e})")

@datumbox
Contributor

datumbox commented Oct 7, 2021

You can try setting batch_size=1 and running on a single GPU if you really want to avoid the issue. You could also set drop_last=True as said above to drop the last sample, but then you will lose img10. In 99.999% of cases, this isn't worth the effort.

There is a trick you can use for some types of models that allows you to identify padded records from real data directly from the output, without storing additional info. The trick is to pad your batch to a fixed size using NaN values. The prediction outputs for these records will be NaNs as well, which makes them easy to distinguish and filter out from the real data.

Here are the types of models where this is not possible:

  • Unstable models that produce NaNs randomly
  • Models that have post-processing steps that filter or aggregate the CNN outputs (like in Detection)
  • GANs that have to produce results on training mode instead of inference mode
  • Exotic layers/techniques that keep using mini-batch stats across images even on inference mode

Here is a POC which takes all the Classification and Segmentation models of TorchVision and shows that such padding is possible:

import torch
import torchvision

# image 0 has valid data. Image 1 is used only for pad and contains only nans
batch = torch.randn((2, 3, 224, 224))
batch[1] *= float("nan")

contains_nans = lambda t: torch.isnan(t).any()
get_models = lambda module: [k for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


print("Classification:")
for model_name in get_models(torchvision.models):
    model = torchvision.models.__dict__[model_name](pretrained=False).eval()
    try:
        with torch.no_grad():
            out = model(batch)

            assert not contains_nans(out[0]), "Found nans on the real image!"
            assert contains_nans(out[1]), "No nan data found on the padded image!"
    except Exception as e:
        print(f"[FAIL] {model_name}: {e}")
        continue
    print(f"[SUCCESS] {model_name}")


print("Segmentation:")
for model_name in get_models(torchvision.models.segmentation):
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=False, pretrained_backbone=False).eval()
    try:
        with torch.no_grad():
            out = model(batch)['out']

            assert not contains_nans(out[0]), "Found nans on the real image!"
            assert contains_nans(out[1]), "No nan data found on the padded image!"
    except Exception as e:
        print(f"[FAIL] {model_name}: {e}")
        continue
    print(f"[SUCCESS] {model_name}")

Personally I find this too much of a hassle. The mitigations I described in my previous comment should be enough for most users/use-cases.

@NicolasHug
Member Author

NicolasHug commented Oct 7, 2021

You could also do drop_last=True as said above to drop the last sample, but now you will lose img10

drop_last=True will indeed drop some images, but it can drop more than just the last one. For example with len(dataset) == 10, batch_size=2 and world_size=6, the last 4 images will be dropped. Agreed, this is a weird setting, but it can still significantly influence the validation result, especially on small datasets.
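
For reference, the arithmetic behind that example, using the per-rank sample count that DistributedSampler uses when drop_last=True (a quick sketch):

def dropped_samples(dataset_len, world_size):
    # With drop_last=True, each rank keeps dataset_len // world_size samples,
    # so everything beyond the largest multiple of world_size is dropped.
    return dataset_len - (dataset_len // world_size) * world_size

print(dropped_samples(10, 6))  # 4, as in the example above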

As far as I can tell, the only correct way is to have batch_size == world_size == 1 as you already noted, or to have a combination such that len(dataset) % (batch_size * world_size) == 0.

I agree that in some cases, e.g. during training, we don't necessarily need to have 100% accurate results and the current behaviour is good enough.

But I think we should still aim at providing exact validation results on all models. It's very likely that some users / researchers are using this code without realizing these subtle caveats, and we need to prevent them from reporting incorrect results.

@datumbox
Contributor

datumbox commented Oct 7, 2021

I don't think there is any disagreement on how each option works or on their pros/cons. The reference scripts are there as a starting point for people to implement their own solutions and adjust them to their needs. If their need is to report precise numbers, then they would have to implement one of the solutions outlined above.

Currently the scripts are tuned to our needs, which are to train our pre-trained models. In our use-case the datasets are big and the padded examples won't noticeably change the accuracy. So unless one proves that the effect on the accuracy of the current models is massive, I'm not sure implementing a solution that will increase the complexity of the scripts would be worth it.

Perhaps @fmassa can provide background info on whether this was ever a problem to any of the research work he was involved, as many of these scripts are ported from those repos.

@NicolasHug
Member Author

If their need is to report precise numbers, then they would have to implement one of the solutions outlined above.

In general, people do need to report precise numbers. I think we do, and I think our users expect accurate, correct results as well. OTOH I don't think it's fair for us to expect our users to identify these subtle bugs or to be aware of them, so I wouldn't expect them to use anything else but what we provide them with.

We're telling users that they can rely on our reference scripts for training - why should it be any different for validation?

If we identify a bug, we owe it to our users to try to fix it. I disagree with the sentiment that this doesn't matter, however small the difference in results might be.

@fmassa
Member

fmassa commented Oct 7, 2021

Hi,

There are a lot of points being discussed in here. Let's split this up in parts:

Simple way of fixing this for classification

Just ask the users to evaluate on a single GPU, for any batch size. We already handle this case properly (and we keep drop_last=False in the dataloader).

Batch size 1

This is actually only necessary for detection models, as we perform padding on the images. So in order to keep results consistent we evaluate with batch size 1.

But note that the detection references are the only ones that currently handle distributed evaluation properly. What we do is to return the image id together with the image / target in the dataset, and keep a map of {image_id: acc} for every image in the dataset per process. Then we aggregate over all processes (which removes the duplicates) and average out for the final numbers.

IMO this is the most robust way of fixing this, but this introduces changes to the rest of the pipeline.

Note that a key part of this approach is returning the image_id together with the data from the dataset -- the refactoring of the datasets does this natively, so it will open the door for us to implement this in a clean and concise way everywhere.
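
In sketch form (simplified, not the actual detection reference code), the aggregation looks something like this:

import torch.distributed as dist

def gather_accuracy(local_results):
    # local_results maps image_id -> 1.0/0.0 (prediction correct or not) for the
    # images this process saw. Gathering the dicts and merging them collapses any
    # image_id that was seen by more than one process, removing the duplicates.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_results)
    merged = {}
    for partial in gathered:
        merged.update(partial)
    return sum(merged.values()) / len(merged)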

@datumbox
Contributor

datumbox commented Oct 7, 2021

@fmassa I agree a simple approach is preferred here. Further increasing the complexity of the reference scripts goes in the opposite direction of where we are aiming. The STL seems a much better place to handle such details.

@NicolasHug I think it was obvious, but to avoid misinterpretations of what I said: by "precise numbers" I meant differences at the 5th decimal place. We typically report results up to the 1st or 3rd decimal digit for models, so I believe the changes won't be visible. But you are welcome to quantify the change and show otherwise.

@NicolasHug
Member Author

NicolasHug commented Oct 11, 2021

I ran

torchrun --nproc_per_node=8 references/classification/train.py --model resnet18 --test-only --pretrained

65 times on main and here are the results. The variance affects the results around the 2nd / 3rd decimal, so IMHO it's high enough for us to address this.

Acc@1: resnet18
count    65.000000
mean     69.762923
std       0.003555
min      69.754000
25%      69.760000
50%      69.764000
75%      69.766000
max      69.770000
dtype: float64

Acc@5: resnet18
count    65.000000
mean     89.070246
std       0.002194
min      89.068000
25%      89.068000
50%      89.070000
75%      89.072000
max      89.076000
dtype: float64

Regarding the suggestion to use a single GPU and set shuffle=False, I agree this is a reasonable temporary solution, until we address this problem properly, and provided that this yields consistently the same results. However it looks like there are additional sources of randomness in the results - would you guys mind taking a look at #4587 ?

EDIT: thanks @datumbox for checking #4587. It seems clear from there that running the validation on a single GPU isn't enough to get exact and consistent results. We still need to have len(dataset) % (batch_size * world_size) == 0.

@datumbox
Contributor

datumbox commented Oct 11, 2021

Thanks for running experiments to quantify the size of the effect.

The variation is not massive but noticeable. It might be worth applying some of the temporary solutions described above. To summarize, we could initialize the test_sampler with shuffle=False and, when --test-only is provided, check whether len(dataset) % (batch_size * world_size) == 0. If that's not true, one approach is to make the necessary changes to args (such as setting the batch size and world size to 1) to ensure the user gets correct stats. Another way is to raise a warning about the variation and explain more visibly why it happens; I think the current comment in the code is a bit hidden.
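
As a sketch, the --test-only check could be as simple as this (not a final patch):

import warnings

def check_eval_padding(dataset_len, batch_size, world_size):
    # Sketch of the proposed check: warn when the distributed sampler may have to
    # pad (duplicate) samples during evaluation, so the user knows the reported
    # accuracy can vary slightly between runs.
    if dataset_len % (batch_size * world_size) != 0:
        warnings.warn(
            f"len(dataset)={dataset_len} is not divisible by batch_size * world_size "
            f"= {batch_size * world_size}; some samples may be duplicated and the "
            "evaluation results may be slightly biased."
        )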

For a permanent fix, this is probably something we should investigate in the new recipes where we will have the opportunity to inherit code. I remember @kazhang raising this in one of our previous discussions. We need to find a solution that works for all models and all datasets; for example, the aforementioned image_id approach is only possible on COCO.

@NicolasHug
Member Author

we could initialize the test_sampler with shuffle=True

Did you mean shuffle=False?

For the rest I agree on all fronts :)

I'll submit a PR soon.

@datumbox
Contributor

Did you mean shuffle=False?

Doh... Yes I meant False. I'll edit in place to avoid confusion of future readers. :p

@NicolasHug
Member Author

NicolasHug commented Oct 11, 2021

An alternative solution that fixes the issue while still allowing for distributed evaluation would be to set drop_last=True, leverage DDP for as many samples as we can, and then account for the remaining samples on the main process.
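
Roughly, the idea looks like this (a sketch only, not the actual code of the PR):

from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler

def build_eval_loaders(dataset, batch_size, world_size, rank):
    # Evaluate the evenly divisible part of the dataset with DDP (no padding is
    # needed since its length is a multiple of world_size), and let the main
    # process handle the leftover samples so that every image is seen exactly once.
    n_even = (len(dataset) // world_size) * world_size
    even_part = Subset(dataset, range(n_even))
    sampler = DistributedSampler(even_part, shuffle=False)
    loaders = [DataLoader(even_part, batch_size=batch_size, sampler=sampler)]
    if rank == 0 and n_even < len(dataset):
        tail = Subset(dataset, range(n_even, len(dataset)))
        loaders.append(DataLoader(tail, batch_size=batch_size))
    return loaders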

I opened #4600 to illustrate this

Thoughts?

@NicolasHug
Member Author

I'm starting to think that the variance reported in #4559 (comment) isn't due to the bug that we're discussing here, but to something else - possibly the model, or the evaluation process.

Using the same command as above and printing the batch size that each worker uses shows that all 8 processes use a batch size of 32 for 195 batches, and a batch size of 10 for the last batch. The dataset has 50k samples and we have

(195 * 32 * 8) + (1 * 10 * 8) == 50000

So all samples are properly treated in my experiment above, there's no duplicated sample, and there's no missing sample either.

So this seems to suggest that there is some extra source of randomness that we're not controlling somehow.

@datumbox
Contributor

Thanks for clarifying and digging deeper. It's great that you're doing this investigation because, before deciding how to move forward, we will need to quantify the impact again (the variation in stats) and attribute it to the right factor.

Concerning your prototype, I left comments on the PR. It's a nice approach but let's defer any decisions for after you conclude your above investigation.

@NicolasHug
Member Author

OK, things are a bit clearer to me now.

As @fmassa suggested (thanks!), the variance might come from the non-deterministic algorithms in use. I set the following:

    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True

and now I'm getting consistent results across batch sizes. I also patched the code to figure out how many samples are effectively processed. Here are the results for different numbers of processes, all of them using the default batch size of 32:

Test:  Acc@1 69.762 Acc@5 89.076, 50000 samples processed -- world_size = 1 or 2 or 4 or 5 or 8
Test:  Acc@1 69.760 Acc@5 89.075, 50004 samples processed -- world_size = 6
Test:  Acc@1 69.761 Acc@5 89.074, 50001 samples processed -- world_size = 3 or 7

So:

  • my previous understanding that we really need to have len(dataset) % (batch_size * world_size) == 0 is wrong. Sometimes, the DataLoader can reduce the batch size of the last batches so that exactly len(dataset) samples are processed. Setting world_size == 1 as @fmassa suggested above should indeed always process exactly len(dataset) samples, no matter the batch size.
  • There's still some visible variance in the result across world_size * batch_size values, but it's not as high as what the previous analysis in #4559 (comment) would suggest. That being said, for other dataset sizes the number of duplicated samples may be higher. I might be wrong but I think that in the worst case we can have up to world_size - 1 duplicated samples. For small datasets this might impact the result quite a bit, but it doesn't matter too much for our datasets. (A quick check of the sample counts above is sketched below.)
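
Quick check of those counts, mirroring the way DistributedSampler pads its index list (my own arithmetic, with len(dataset) == 50000):

import math

for world_size in range(1, 9):
    per_rank = math.ceil(50000 / world_size)
    print(world_size, per_rank * world_size)
# world_size 3 and 7 give 50001, world_size 6 gives 50004, all others give 50000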

Thanks both for your input!!

Considering that most of the variance is removed by disabling non-deterministic algorithms as above, I would suggest to just set these flags as above when --test-only is passed, and to keep #4600 in the back of our minds for the next version of the references / recipes.

I think we could also raise a warning if not exactly len(dataset) samples have been processed, to let the user know that the results might be slightly biased. This would require a small patch like this:

diff --git a/references/classification/train.py b/references/classification/train.py
index a71d337a..de520fb3 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -54,6 +54,13 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+    def _reduce(val):
+        val = torch.tensor([val], dtype=torch.int, device="cuda")
+        torch.distributed.barrier()
+        torch.distributed.all_reduce(val)
+        return val.item()
+
+    n_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +75,12 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            n_samples += batch_size
     # gather the stats from all processes
+
+    n_samples = _reduce(n_samples)
+    print(f"We processed {n_samples} in total")
