Evaluation code of references is slightly off #4559
Thanks for raising. Here are some thoughts:
I think lightning is likely to face similar issues, but for now I think it's worth patching batch_size=1 when
Thanks for sharing your thoughts @datumbox. I'm not sure that setting the batch_size to 1 will avoid the issue: we'll still get duplicated images if len(dataset) isn't divisible by the number of processes.
source:

```python
import os

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __getitem__(self, index):
        return index

    def __len__(self):
        return 10


# RANK, LOCAL_RANK and WORLD_SIZE are set by the distributed launcher (e.g. torchrun).
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method="env://")

dataset = MyDataset()
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False, drop_last=False)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)

# Print, rank by rank, which sample index each process received.
for e in dataloader:
    for r in range(world_size):
        torch.distributed.barrier()
        if r == rank:
            print(f"I'm rank {rank} and I got {e}")
```
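To see why batch_size=1 doesn't help, here is a smaller single-process illustration of the padding behaviour (a sketch that passes num_replicas/rank explicitly so no process group is needed): with 10 samples and 3 replicas, DistributedSampler pads the index list to 12 by repeating the first indices, so two samples get assigned to two ranks each.

```python
from torch.utils.data.distributed import DistributedSampler

# No process group needed when num_replicas/rank are passed explicitly.
dataset = list(range(10))
for rank in range(3):
    sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False, drop_last=False)
    print(rank, list(sampler))
# rank 0 -> [0, 3, 6, 9], rank 1 -> [1, 4, 7, 0], rank 2 -> [2, 5, 8, 1]
# Indices 0 and 1 are duplicated because 10 is not divisible by 3.
```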
You can try setting batch_size=1 and running on a single GPU if you really want to avoid the issue. You could also do

There is a trick you can use for some types of models that allows you to identify padded records from real data directly from the output, without storing additional info. The trick is to pad your batch to a fixed size using NaN values. The prediction outputs of these records will be NaNs as well, which makes it easy to distinguish and filter them from real data without additional info. Here are the types of models where this is not possible:
Here is a POC which takes all the Classification and Segmentation models of TorchVision and shows that such padding is possible:

```python
import torch
import torchvision

# Image 0 has valid data. Image 1 is used only for padding and contains only NaNs.
batch = torch.randn((2, 3, 224, 224))
batch[1] *= float("nan")

contains_nans = lambda t: torch.isnan(t).any()
# Collect the public model builder functions of a torchvision module.
get_models = lambda module: [k for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

print("Classification:")
for model_name in get_models(torchvision.models):
    model = torchvision.models.__dict__[model_name](pretrained=False).eval()
    try:
        with torch.no_grad():
            out = model(batch)
        assert not contains_nans(out[0]), "Found nans on the real image!"
        assert contains_nans(out[1]), "No nan data found on the padded image!"
    except Exception as e:
        print(f"[FAIL] {model_name}: {e}")
        continue
    print(f"[SUCCESS] {model_name}")

print("Segmentation:")
for model_name in get_models(torchvision.models.segmentation):
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=False, pretrained_backbone=False).eval()
    try:
        with torch.no_grad():
            out = model(batch)['out']
        assert not contains_nans(out[0]), "Found nans on the real image!"
        assert contains_nans(out[1]), "No nan data found on the padded image!"
    except Exception as e:
        print(f"[FAIL] {model_name}: {e}")
        continue
    print(f"[SUCCESS] {model_name}")
```

Personally I find this too much of a hassle. The mitigations I described in my previous comment should be enough for most users/use-cases.
As far as I can tell, the only correct way is to have

I agree that in some cases, e.g. during training, we don't necessarily need 100% accurate results and the current behaviour is good enough. But I think we should still aim at providing exact validation results for all models. It's very likely that some users / researchers are using this code without realizing these subtle caveats, and we need to prevent them from reporting incorrect results.
I don't think there is any disagreement on how each option works or on their pros/cons. The reference scripts are there as a starting point for people to implement their own solutions and adjust them to their needs. If their need is to report precise numbers, then they would have to implement one of the solutions outlined above.

Currently the scripts are tuned to our needs, which are to train our pre-trained models. In our use-case the datasets are big and the padded examples won't noticeably change the accuracy. So unless someone proves that the effect on the accuracy of the current models is massive, I'm not sure implementing a solution that increases the complexity of the scripts would be worth it. Perhaps @fmassa can provide background info on whether this was ever a problem in any of the research work he was involved in, as many of these scripts are ported from those repos.
In general, people do need to report precise numbers. I think we do, and I think our users expect accurate, correct results as well. OTOH I don't think it's fair for us to expect our users to identify these subtle bugs or to be aware of them, so I wouldn't expect them to use anything else but what we provide them with. We're telling users that they can rely on our reference scripts for training - why should it be any different for validation?

If we identify a bug, we owe it to our users to try to fix it. I disagree with the sentiment that this doesn't matter, however small the difference in results might be.
Hi,

There are a lot of points being discussed in here. Let's split this up in parts:

Simple way of fixing this for classification

Just ask the users to evaluate on a single GPU, for any batch size. We handle this case properly already (and we keep

Batch size 1

This is actually only necessary for detection models, as we perform padding on the images. So in order to keep results consistent we evaluate with batch size 1. But note that the detection references are the only ones that currently handle distributed evaluation properly. What we do is return the image id together with the image / target in the dataset, and keep a map of the

IMO this is the most robust way of fixing this, but it introduces changes to the rest of the pipeline. If we look at it carefully, a key part of this approach is to return the
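For illustration, here is a minimal sketch of that image-id bookkeeping. The helper below is hypothetical (the detection references implement this through their own evaluator pipeline, not this exact code): each process keeps a dict keyed by image id, the dicts are gathered across processes, and any duplicates introduced by the sampler collapse because they share the same key.

```python
import torch.distributed as dist

def gather_predictions_by_image_id(local_preds):
    # local_preds: dict mapping image_id -> prediction for the samples seen by this process.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_preds)  # every process receives every dict

    merged = {}
    for preds in gathered:
        # Duplicated image ids (from sampler padding) overwrite instead of being counted twice.
        merged.update(preds)
    return merged
```

The final metric is then computed over merged.values(), so each image contributes exactly once regardless of world size or batch size.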
@fmassa I agree a simple approach is preferred here. Increasing the complexity of the reference scripts further goes in the opposite direction of where we aim. The STL seems a much better place to handle such details.

@NicolasHug I think it was obvious, but to avoid misinterpretations of what I said: by "precise numbers" I meant differences on the 5th decimal point. We typically report results up to the 1st or 3rd digit on models, so I believe the changes won't be visible. But you are welcome to quantify the change and show otherwise.
I ran
65 times on
Regarding the suggestion to use a single GPU and set shuffle=False,

EDIT: thanks @datumbox for checking #4587. It seems clear from there that running the validation on a single GPU isn't enough to get exact and consistent results. We still need to have
Thanks for running experiments to quantify the size of the effect. The variation is not massive but it is noticeable. It might be worth applying some of the temporary solutions described above. To summarize, we could initialize the

For a permanent fix, this is probably something we should investigate in the new recipes, where we will have the opportunity to inherit code. I remember @kazhang raising this in the past in one of our previous discussions. We need to find a solution that works on all models and all datasets, for example the aforementioned
Did you mean shuffle=False? For the rest I agree on all fronts :) I'll submit a PR soon.
Doh... Yes, I meant shuffle=False.
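For concreteness, the temporary mitigation amounts to building the evaluation sampler without shuffling, along these lines (a sketch; dataset_test stands for the validation dataset used by the classification script):

```python
# Keep the evaluation order fixed so results are reproducible across runs,
# even if the sampler still pads/duplicates a few samples.
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
```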
An alternative solution to fix the issue while still allowing for distributed settings would be to set

I opened #4600 to illustrate this. Thoughts?
I'm starting to think that the variance reported in #4559 (comment) isn't due to the bug that we're discussing here, but to something else - possibly the model, or the evaluation process. Using the same command as above and printing the batch size that each worker is using, it shows that all 8 processes use a batch size of 32 for 195 batches, and a batch size of 10 for the last batch. The dataset has 50k samples and we have
So all samples are properly treated in my experiment above: there are no duplicated samples and no missing samples either. This seems to suggest that there is some extra source of randomness that we're not controlling somehow.
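For reference, a quick sanity check of those numbers (the batch counts are the ones observed per worker in the run above; 50k is the size of the validation set):

```python
world_size = 8
batch_size = 32
full_batches, last_batch = 195, 10  # per worker: 195 batches of 32 plus one batch of 10

per_worker = full_batches * batch_size + last_batch  # 6250 samples per process
total = per_worker * world_size                      # 50000, i.e. exactly the dataset size
print(per_worker, total)
```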
Thanks for clarifying and digging deeper. It's great that you are doing this investigation, because before deciding how to move forward we will need to quantify the impact again (the variation of the stats) and attribute it to the right factor.

Concerning your prototype, I left comments on the PR. It's a nice approach, but let's defer any decisions until after you conclude the investigation above.
OK, things are a bit clearer to me now. As @fmassa suggested (thanks!), the variance might come from the non-deterministic algorithms that are in use. I set the following:

```python
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
```

and now I'm getting consistent results across batch sizes.

I also patched the code to figure out how many samples are effectively processed. Here are the results for different numbers of processes, all of them using the default batch-size = 32:
So:
Thanks both for your input!! Considering most of the variance is captured by disabling stochastic algorithms as above, I would suggest to just set these flags to True if

I think we could also raise a warning if not exactly

```diff
diff --git a/references/classification/train.py b/references/classification/train.py
index a71d337a..de520fb3 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -54,6 +54,13 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+    def _reduce(val):
+        val = torch.tensor([val], dtype=torch.int, device="cuda")
+        torch.distributed.barrier()
+        torch.distributed.all_reduce(val)
+        return val.item()
+
+    n_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +75,12 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            n_samples += batch_size
     # gather the stats from all processes
+
+    n_samples = _reduce(n_samples)
+    print(f"We processed {n_samples} in total")
```
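As a possible follow-up inside the patched evaluate function above (not part of the patch itself), the reduced count could be compared against the dataset size to flag duplicated or dropped samples automatically:

```python
# Hypothetical check: warn when the distributed evaluation didn't process
# exactly one copy of each validation sample.
expected = len(data_loader.dataset)
if n_samples != expected:
    print(f"Warning: processed {n_samples} samples but the dataset has {expected}")
```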
There is a subtle known bug in the evaluation code of the classification references (and other references as well, but not all):
vision/references/classification/train.py, lines 65 to 66 in 261cbf7
It deserves some attention, because it's easy to miss and yet can impact our reported results, and those of research papers.
As the comment above describes, when computing the accuracy of the model on a validation set in a distributed setting, some images will be counted more than once if len(dataset) isn't divisible by batch_size * world_size [1]. On top of that, since the test_sampler uses shuffle=True by default, the duplicated images aren't even the same across executions, which means that evaluating the same model on the same dataset can lead to different results every time.

Should we try to fix this, or should we just leave it and wait for the new lightning recipes to handle it? And as a follow-up question, is there a builtin way in lightning to mitigate this at all? (I'm not familiar with lightning, so this one may not make sense.)
cc @datumbox
Footnotes

[1] For example, if we have 10 images and 2 workers with a batch_size of 3, we will have something like: