diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp index af3c1f4d1673c..050f5232c0408 100644 --- a/compiler-rt/lib/dfsan/dfsan_custom.cpp +++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp @@ -2198,50 +2198,12 @@ struct Formatter { return retval; } - int scan() { - char *tmp_fmt = build_format_string(true); - int read_count = 0; - int retval = sscanf(str + str_off, tmp_fmt, &read_count); - if (retval > 0) { - if (-1 == num_scanned) - num_scanned = 0; - num_scanned += retval; - } - free(tmp_fmt); - return read_count; - } - - template <typename T> - int scan(T arg) { - char *tmp_fmt = build_format_string(true); - int read_count = 0; - int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count); - if (retval > 0) { - if (-1 == num_scanned) - num_scanned = 0; - num_scanned += retval; - } - free(tmp_fmt); - return read_count; - } - - // with_n -> toggles adding %n on/off; off by default - char *build_format_string(bool with_n = false) { + char *build_format_string() { size_t fmt_size = fmt_cur - fmt_start + 1; - size_t add_size = 0; - if (with_n) - add_size = 2; - char *new_fmt = (char *)malloc(fmt_size + 1 + add_size); + char *new_fmt = (char *)malloc(fmt_size + 1); assert(new_fmt); internal_memcpy(new_fmt, fmt_start, fmt_size); - if (!with_n) { - new_fmt[fmt_size] = '\0'; - } else { - new_fmt[fmt_size] = '%'; - new_fmt[fmt_size + 1] = 'n'; - new_fmt[fmt_size + 2] = '\0'; - } - + new_fmt[fmt_size] = '\0'; return new_fmt; } @@ -2467,6 +2429,102 @@ static int format_buffer(char *str, size_t size, const char *fmt, return formatter.str_off; } +// Scans a chunk: either a constant string or a single format directive (e.g., +// '%.3f'). +struct Scanner { + Scanner(char *str_, const char *fmt_, size_t size_) + : str(str_), + str_off(0), + size(size_), + fmt_start(fmt_), + fmt_cur(fmt_), + width(-1), + num_scanned(0), + skip(false) {} + + // Consumes a chunk of ordinary characters. + // Returns number of matching ordinary characters.
+ // Returns -1 if the match failed. + // In format strings, a space will match multiple spaces. + int check_match_ordinary() { + char *tmp_fmt = build_format_string_with_n(); + int read_count = -1; + sscanf(str + str_off, tmp_fmt, &read_count); + free(tmp_fmt); + if (read_count > 0) { + str_off += read_count; + } + return read_count; + } + + int scan() { + char *tmp_fmt = build_format_string_with_n(); + int read_count = 0; + int retval = sscanf(str + str_off, tmp_fmt, &read_count); + free(tmp_fmt); + if (retval > 0) { + num_scanned += retval; + } + return read_count; + } + + template <typename T> + int scan(T arg) { + char *tmp_fmt = build_format_string_with_n(); + int read_count = 0; + int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count); + free(tmp_fmt); + if (retval > 0) { + num_scanned += retval; + } + return read_count; + } + + // Adds %n onto current format string to measure length. + char *build_format_string_with_n() { + size_t fmt_size = fmt_cur - fmt_start + 1; + // +2 for %n, +1 for \0 + char *new_fmt = (char *)malloc(fmt_size + 2 + 1); + assert(new_fmt); + internal_memcpy(new_fmt, fmt_start, fmt_size); + new_fmt[fmt_size] = '%'; + new_fmt[fmt_size + 1] = 'n'; + new_fmt[fmt_size + 2] = '\0'; + return new_fmt; + } + + char *str_cur() { return str + str_off; } + + size_t num_written_bytes(int retval) { + if (retval < 0) { + return 0; + } + + size_t num_avail = str_off < size ? size - str_off : 0; + if (num_avail == 0) { + return 0; + } + + size_t num_written = retval; + // A return value of {v,}snprintf of size or more means that the output was + // truncated. + if (num_written >= num_avail) { + num_written -= num_avail; + } + + return num_written; + } + + char *str; + size_t str_off; + size_t size; + const char *fmt_start; + const char *fmt_cur; + int width; + int num_scanned; + bool skip; +}; + // This function is an inverse of format_buffer: we take the input buffer, // scan it in search for format strings and store the results in the varargs.
// The labels are propagated from the input buffer to the varargs. @@ -2474,220 +2532,222 @@ static int scan_buffer(char *str, size_t size, const char *fmt, dfsan_label *va_labels, dfsan_label *ret_label, dfsan_origin *str_origin, dfsan_origin *ret_origin, va_list ap) { - Formatter formatter(str, fmt, size); - while (*formatter.fmt_cur) { - formatter.fmt_start = formatter.fmt_cur; - formatter.width = -1; - formatter.skip = false; + Scanner scanner(str, fmt, size); + while (*scanner.fmt_cur) { + scanner.fmt_start = scanner.fmt_cur; + scanner.width = -1; + scanner.skip = false; int read_count = 0; void *dst_ptr = 0; size_t write_size = 0; - if (*formatter.fmt_cur != '%') { - // Ordinary character. Consume all the characters until a '%' or the end - // of the string. - for (; *(formatter.fmt_cur + 1) && *(formatter.fmt_cur + 1) != '%'; - ++formatter.fmt_cur) { + if (*scanner.fmt_cur != '%') { + // Ordinary character and spaces. + // Consume all the characters until a '%' or the end of the string. + for (; *(scanner.fmt_cur + 1) && *(scanner.fmt_cur + 1) != '%'; + ++scanner.fmt_cur) { + } + if (scanner.check_match_ordinary() < 0) { + // The ordinary characters did not match. + break; } - read_count = formatter.scan(); - dfsan_set_label(0, formatter.str_cur(), - formatter.num_written_bytes(read_count)); } else { // Conversion directive. Consume all the characters until a conversion // specifier or the end of the string. bool end_fmt = false; - for (; *formatter.fmt_cur && !end_fmt;) { - switch (*++formatter.fmt_cur) { - case 'd': - case 'i': - case 'o': - case 'u': - case 'x': - case 'X': - if (formatter.skip) { - read_count = formatter.scan(); - } else { - switch (*(formatter.fmt_cur - 1)) { - case 'h': - // Also covers the 'hh' case (since the size of the arg is still - // an int). 
- dst_ptr = va_arg(ap, int *); - read_count = formatter.scan((int *)dst_ptr); - write_size = sizeof(int); - break; - case 'l': - if (formatter.fmt_cur - formatter.fmt_start >= 2 && - *(formatter.fmt_cur - 2) == 'l') { - dst_ptr = va_arg(ap, long long int *); - read_count = formatter.scan((long long int *)dst_ptr); - write_size = sizeof(long long int); - } else { - dst_ptr = va_arg(ap, long int *); - read_count = formatter.scan((long int *)dst_ptr); - write_size = sizeof(long int); + for (; *scanner.fmt_cur && !end_fmt;) { + switch (*++scanner.fmt_cur) { + case 'd': + case 'i': + case 'o': + case 'u': + case 'x': + case 'X': + if (scanner.skip) { + read_count = scanner.scan(); + } else { + switch (*(scanner.fmt_cur - 1)) { + case 'h': + // Also covers the 'hh' case (since the size of the arg is + // still an int). + dst_ptr = va_arg(ap, int *); + read_count = scanner.scan((int *)dst_ptr); + write_size = sizeof(int); + break; + case 'l': + if (scanner.fmt_cur - scanner.fmt_start >= 2 && + *(scanner.fmt_cur - 2) == 'l') { + dst_ptr = va_arg(ap, long long int *); + read_count = scanner.scan((long long int *)dst_ptr); + write_size = sizeof(long long int); + } else { + dst_ptr = va_arg(ap, long int *); + read_count = scanner.scan((long int *)dst_ptr); + write_size = sizeof(long int); + } + break; + case 'q': + dst_ptr = va_arg(ap, long long int *); + read_count = scanner.scan((long long int *)dst_ptr); + write_size = sizeof(long long int); + break; + case 'j': + dst_ptr = va_arg(ap, intmax_t *); + read_count = scanner.scan((intmax_t *)dst_ptr); + write_size = sizeof(intmax_t); + break; + case 'z': + case 't': + dst_ptr = va_arg(ap, size_t *); + read_count = scanner.scan((size_t *)dst_ptr); + write_size = sizeof(size_t); + break; + default: + dst_ptr = va_arg(ap, int *); + read_count = scanner.scan((int *)dst_ptr); + write_size = sizeof(int); + } + // get the label associated with the string at the corresponding + // place + dfsan_label l = dfsan_read_label( + 
scanner.str_cur(), scanner.num_written_bytes(read_count)); + dfsan_set_label(l, dst_ptr, write_size); + if (str_origin != nullptr) { + dfsan_set_label(l, dst_ptr, write_size); + size_t scan_count = scanner.num_written_bytes(read_count); + size_t size = scan_count > write_size ? write_size : scan_count; + dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size); } - break; - case 'q': - dst_ptr = va_arg(ap, long long int *); - read_count = formatter.scan((long long int *)dst_ptr); - write_size = sizeof(long long int); - break; - case 'j': - dst_ptr = va_arg(ap, intmax_t *); - read_count = formatter.scan((intmax_t *)dst_ptr); - write_size = sizeof(intmax_t); - break; - case 'z': - case 't': - dst_ptr = va_arg(ap, size_t *); - read_count = formatter.scan((size_t *)dst_ptr); - write_size = sizeof(size_t); - break; - default: - dst_ptr = va_arg(ap, int *); - read_count = formatter.scan((int *)dst_ptr); - write_size = sizeof(int); - } - // get the label associated with the string at the corresponding - // place - dfsan_label l = dfsan_read_label( - formatter.str_cur(), formatter.num_written_bytes(read_count)); - dfsan_set_label(l, dst_ptr, write_size); - if (str_origin != nullptr) { - dfsan_set_label(l, dst_ptr, write_size); - size_t scan_count = formatter.num_written_bytes(read_count); - size_t size = scan_count > write_size ? 
write_size : scan_count; - dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size); } - } - end_fmt = true; + end_fmt = true; - break; + break; - case 'a': - case 'A': - case 'e': - case 'E': - case 'f': - case 'F': - case 'g': - case 'G': - if (formatter.skip) { - read_count = formatter.scan(); - } else { - if (*(formatter.fmt_cur - 1) == 'L') { - dst_ptr = va_arg(ap, long double *); - read_count = formatter.scan((long double *)dst_ptr); - write_size = sizeof(long double); - } else if (*(formatter.fmt_cur - 1) == 'l') { - dst_ptr = va_arg(ap, double *); - read_count = formatter.scan((double *)dst_ptr); - write_size = sizeof(double); + case 'a': + case 'A': + case 'e': + case 'E': + case 'f': + case 'F': + case 'g': + case 'G': + if (scanner.skip) { + read_count = scanner.scan(); } else { - dst_ptr = va_arg(ap, float *); - read_count = formatter.scan((float *)dst_ptr); - write_size = sizeof(float); - } - dfsan_label l = dfsan_read_label( - formatter.str_cur(), formatter.num_written_bytes(read_count)); - dfsan_set_label(l, dst_ptr, write_size); - if (str_origin != nullptr) { - dfsan_set_label(l, dst_ptr, write_size); - size_t scan_count = formatter.num_written_bytes(read_count); - size_t size = scan_count > write_size ? 
write_size : scan_count; - dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size); + if (*(scanner.fmt_cur - 1) == 'L') { + dst_ptr = va_arg(ap, long double *); + read_count = scanner.scan((long double *)dst_ptr); + write_size = sizeof(long double); + } else if (*(scanner.fmt_cur - 1) == 'l') { + dst_ptr = va_arg(ap, double *); + read_count = scanner.scan((double *)dst_ptr); + write_size = sizeof(double); + } else { + dst_ptr = va_arg(ap, float *); + read_count = scanner.scan((float *)dst_ptr); + write_size = sizeof(float); + } + dfsan_label l = dfsan_read_label( + scanner.str_cur(), scanner.num_written_bytes(read_count)); + dfsan_set_label(l, dst_ptr, write_size); + if (str_origin != nullptr) { + dfsan_set_label(l, dst_ptr, write_size); + size_t scan_count = scanner.num_written_bytes(read_count); + size_t size = scan_count > write_size ? write_size : scan_count; + dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size); + } } - } - end_fmt = true; - break; + end_fmt = true; + break; - case 'c': - if (formatter.skip) { - read_count = formatter.scan(); - } else { - dst_ptr = va_arg(ap, char *); - read_count = formatter.scan((char *)dst_ptr); - write_size = sizeof(char); - dfsan_label l = dfsan_read_label( - formatter.str_cur(), formatter.num_written_bytes(read_count)); - dfsan_set_label(l, dst_ptr, write_size); - if (str_origin != nullptr) { - size_t scan_count = formatter.num_written_bytes(read_count); - size_t size = scan_count > write_size ? 
write_size : scan_count; - dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size); + case 'c': + if (scanner.skip) { + read_count = scanner.scan(); + } else { + dst_ptr = va_arg(ap, char *); + read_count = scanner.scan((char *)dst_ptr); + write_size = sizeof(char); + dfsan_label l = dfsan_read_label( + scanner.str_cur(), scanner.num_written_bytes(read_count)); + dfsan_set_label(l, dst_ptr, write_size); + if (str_origin != nullptr) { + size_t scan_count = scanner.num_written_bytes(read_count); + size_t size = scan_count > write_size ? write_size : scan_count; + dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size); + } } - } - end_fmt = true; - break; + end_fmt = true; + break; - case 's': { - if (formatter.skip) { - read_count = formatter.scan(); - } else { - dst_ptr = va_arg(ap, char *); - read_count = formatter.scan((char *)dst_ptr); - if (1 == read_count) { - // special case: we have parsed a single string and we need to - // update read_count with the string size - read_count = strlen((char *)dst_ptr); + case 's': { + if (scanner.skip) { + read_count = scanner.scan(); + } else { + dst_ptr = va_arg(ap, char *); + read_count = scanner.scan((char *)dst_ptr); + if (1 == read_count) { + // special case: we have parsed a single string and we need to + // update read_count with the string size + read_count = strlen((char *)dst_ptr); + } + if (str_origin) + dfsan_mem_origin_transfer( + dst_ptr, scanner.str_cur(), + scanner.num_written_bytes(read_count)); + va_labels++; + dfsan_mem_shadow_transfer(dst_ptr, scanner.str_cur(), + scanner.num_written_bytes(read_count)); } - if (str_origin) - dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), - formatter.num_written_bytes(read_count)); - va_labels++; - dfsan_mem_shadow_transfer(dst_ptr, formatter.str_cur(), - formatter.num_written_bytes(read_count)); + end_fmt = true; + break; } - end_fmt = true; - break; - } - case 'p': - if (formatter.skip) { - read_count = formatter.scan(); - } else { - dst_ptr = 
va_arg(ap, void *); - read_count = - formatter.scan((int *)dst_ptr); // note: changing void* to int* + case 'p': + if (scanner.skip) { + read_count = scanner.scan(); + } else { + dst_ptr = va_arg(ap, void *); + read_count = + scanner.scan((int *)dst_ptr); // note: changing void* to int* // since we need to call sizeof - write_size = sizeof(int); - - dfsan_label l = dfsan_read_label( - formatter.str_cur(), formatter.num_written_bytes(read_count)); - dfsan_set_label(l, dst_ptr, write_size); - if (str_origin != nullptr) { - dfsan_set_label(l, dst_ptr, write_size); - size_t scan_count = formatter.num_written_bytes(read_count); - size_t size = scan_count > write_size ? write_size : scan_count; - dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size); + write_size = sizeof(int); + + dfsan_label l = dfsan_read_label( + scanner.str_cur(), scanner.num_written_bytes(read_count)); + dfsan_set_label(l, dst_ptr, write_size); + if (str_origin != nullptr) { + dfsan_set_label(l, dst_ptr, write_size); + size_t scan_count = scanner.num_written_bytes(read_count); + size_t size = scan_count > write_size ? 
write_size : scan_count; + dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size); + } } - } - end_fmt = true; - break; + end_fmt = true; + break; - case 'n': { - if (!formatter.skip) { - int *ptr = va_arg(ap, int *); - *ptr = (int)formatter.str_off; - *va_labels++ = 0; - dfsan_set_label(0, ptr, sizeof(*ptr)); - if (str_origin != nullptr) - *str_origin++ = 0; + case 'n': { + if (!scanner.skip) { + int *ptr = va_arg(ap, int *); + *ptr = (int)scanner.str_off; + *va_labels++ = 0; + dfsan_set_label(0, ptr, sizeof(*ptr)); + if (str_origin != nullptr) + *str_origin++ = 0; + } + end_fmt = true; + break; } - end_fmt = true; - break; - } - case '%': - read_count = formatter.scan(); - end_fmt = true; - break; + case '%': + read_count = scanner.scan(); + end_fmt = true; + break; - case '*': - formatter.skip = true; - break; + case '*': + scanner.skip = true; + break; - default: - break; + default: + break; } } } @@ -2697,8 +2757,8 @@ static int scan_buffer(char *str, size_t size, const char *fmt, return read_count; } - formatter.fmt_cur++; - formatter.str_off += read_count; + scanner.fmt_cur++; + scanner.str_off += read_count; } (void)va_labels; // Silence unused-but-set-parameter warning @@ -2707,7 +2767,7 @@ static int scan_buffer(char *str, size_t size, const char *fmt, *ret_origin = 0; // Number of items scanned in total. 
- return formatter.num_scanned; + return scanner.num_scanned; } extern "C" { diff --git a/compiler-rt/test/dfsan/sscanf.c b/compiler-rt/test/dfsan/sscanf.c index dbc2de4ba96c1..88325642ef5e3 100644 --- a/compiler-rt/test/dfsan/sscanf.c +++ b/compiler-rt/test/dfsan/sscanf.c @@ -1,18 +1,111 @@ // RUN: %clang_dfsan %s -o %t && %run %t -// XFAIL: * #include <assert.h> #include <stdio.h> int main(int argc, char *argv[]) { - char buf[256] = "10000000000-100000000000 rw-p 00000000 00:00 0"; - long rss = 0; - // This test exposes a bug in DFSan's sscanf, that leads to flakiness - // in release_shadow_space.c (see - // https://github.com/llvm/llvm-project/issues/91287) - if (sscanf(buf, "Garbage text before, %ld, Garbage text after", &rss) == 1) { - printf("Error: matched %ld\n", rss); - return 1; + { + char buf[256] = "10000000000-100000000000 rw-p 00000000 00:00 0"; + long rss = 0; + // This test exposes a bug in DFSan's sscanf, that leads to flakiness + // in release_shadow_space.c (see + // https://github.com/llvm/llvm-project/issues/91287) + int r = sscanf(buf, "Garbage text before, %ld, Garbage text after", &rss); + assert(r == 0); + } + + // Testing other variations of sscanf behavior.
+ { + int a = 0; + int b = 0; + int r = sscanf("abc42 cat 99", "abc%d cat %d", &a, &b); + assert(a == 42); + assert(b == 99); + assert(r == 2); + } + + { + int a = 0; + int b = 0; + int r = sscanf("abc42 cat 99", "abc%d dog %d", &a, &b); + assert(a == 42); + assert(r == 1); + } + + { + int a = 0; + int b = 0; + int r = sscanf("abx42 dog 99", "abc%d dog %d", &a, &b); + assert(r == 0); + } + + { + int r = sscanf("abx", "abc"); + assert(r == 0); + } + + { + int r = sscanf("abc", "abc"); + assert(r == 0); + } + + { + int n = 0; + int r = sscanf("abc", "abc%n", &n); + assert(n == 3); + assert(r == 0); + } + + { + int n = 1234; + int r = sscanf("abxy", "abcd%n", &n); + assert(n == 1234); + assert(r == 0); + } + + { + int a = 0; + int n = 1234; + int r = sscanf("abcd99", "abcd%d%n", &a, &n); + assert(a == 99); + assert(n == 6); + assert(r == 1); + } + + { + int n = 1234; + int r = sscanf("abcdsuffix", "abcd%n", &n); + assert(n == 4); + assert(r == 0); + } + + { + int n = 1234; + int r = sscanf("abxxsuffix", "abcd%n", &n); + assert(n == 1234); + assert(r == 0); + } + + { + int a = 0; + int b = 0; + int n = 1234; + int r = sscanf("abcd99 xy100", "abcd%d xy%d%n", &a, &b, &n); + assert(a == 99); + assert(b == 100); + assert(n == 12); + assert(r == 2); + } + + { + int a = 0; + int b = 0; + int n = 1234; + int r = sscanf("abcd99 xy100", "abcd%d zz%d%n", &a, &b, &n); + assert(a == 99); + assert(b == 0); + assert(n == 1234); + assert(r == 1); } return 0;