Skip to content

Commit 5623521

Browse files
authored
[R] Move all DMatrix fields to function arguments (#9862)
1 parent 1094d60 commit 5623521

File tree

10 files changed

+237
-69
lines changed

10 files changed

+237
-69
lines changed

R-package/R/xgb.DMatrix.R

Lines changed: 111 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,24 @@
88
#' a \code{dgRMatrix} object,
99
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
1010
#' interpreted as a row vector), or a character string representing a filename.
11-
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
12-
#' See \code{\link{setinfo}} for the specific allowed kinds of
11+
#' @param label Label of the training data.
12+
#' @param weight Weight for each instance.
13+
#'
14+
#' Note that, for ranking tasks, weights are per-group. In ranking tasks, one weight
15+
#' is assigned to each group (not each data point). This is because we
16+
#' only care about the relative ordering of data points within each group,
17+
#' so it doesn't make sense to assign weights to individual data points.
18+
#' @param base_margin Base margin used for boosting from existing model.
1319
#' @param missing a float value to represent missing values in data (used only when input is a dense matrix).
1420
#' It is useful when a 0 or some other extreme value represents missing values in data.
1521
#' @param silent whether to suppress printing an informational message after loading from a file.
22+
#' @param feature_names Set names for features.
1623
#' @param nthread Number of threads used for creating DMatrix.
17-
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
24+
#' @param group Group sizes for all ranking groups.
25+
#' @param qid Query ID for data samples, used for ranking.
26+
#' @param label_lower_bound Lower bound for survival training.
27+
#' @param label_upper_bound Upper bound for survival training.
28+
#' @param feature_weights Set feature weights for column sampling.
1829
#'
1930
#' @details
2031
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
@@ -34,8 +45,24 @@
3445
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
3546
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
3647
#' @export
37-
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
38-
cnames <- NULL
48+
xgb.DMatrix <- function(
49+
data,
50+
label = NULL,
51+
weight = NULL,
52+
base_margin = NULL,
53+
missing = NA,
54+
silent = FALSE,
55+
feature_names = colnames(data),
56+
nthread = NULL,
57+
group = NULL,
58+
qid = NULL,
59+
label_lower_bound = NULL,
60+
label_upper_bound = NULL,
61+
feature_weights = NULL
62+
) {
63+
if (!is.null(group) && !is.null(qid)) {
64+
stop("Either one of 'group' or 'qid' should be NULL")
65+
}
3966
if (typeof(data) == "character") {
4067
if (length(data) > 1)
4168
stop("'data' has class 'character' and length ", length(data),
@@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
4471
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
4572
} else if (is.matrix(data)) {
4673
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
47-
cnames <- colnames(data)
4874
} else if (inherits(data, "dgCMatrix")) {
4975
handle <- .Call(
5076
XGDMatrixCreateFromCSC_R,
@@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
5581
missing,
5682
as.integer(NVL(nthread, -1))
5783
)
58-
cnames <- colnames(data)
5984
} else if (inherits(data, "dgRMatrix")) {
6085
handle <- .Call(
6186
XGDMatrixCreateFromCSR_R,
@@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
6691
missing,
6792
as.integer(NVL(nthread, -1))
6893
)
69-
cnames <- colnames(data)
7094
} else if (inherits(data, "dsparseVector")) {
7195
indptr <- c(0L, as.integer(length(data@i)))
7296
ind <- as.integer(data@i) - 1L
@@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
82106
} else {
83107
stop("xgb.DMatrix does not support construction from ", typeof(data))
84108
}
109+
85110
dmat <- handle
86111
attributes(dmat) <- list(class = "xgb.DMatrix")
87-
if (!is.null(cnames)) {
88-
setinfo(dmat, "feature_name", cnames)
89-
}
90112

91-
info <- append(info, list(...))
92-
for (i in seq_along(info)) {
93-
p <- info[i]
94-
setinfo(dmat, names(p), p[[1]])
113+
if (!is.null(label)) {
114+
setinfo(dmat, "label", label)
115+
}
116+
if (!is.null(weight)) {
117+
setinfo(dmat, "weight", weight)
118+
}
119+
if (!is.null(base_margin)) {
120+
setinfo(dmat, "base_margin", base_margin)
121+
}
122+
if (!is.null(feature_names)) {
123+
setinfo(dmat, "feature_name", feature_names)
124+
}
125+
if (!is.null(group)) {
126+
setinfo(dmat, "group", group)
127+
}
128+
if (!is.null(qid)) {
129+
setinfo(dmat, "qid", qid)
95130
}
131+
if (!is.null(label_lower_bound)) {
132+
setinfo(dmat, "label_lower_bound", label_lower_bound)
133+
}
134+
if (!is.null(label_upper_bound)) {
135+
setinfo(dmat, "label_upper_bound", label_upper_bound)
136+
}
137+
if (!is.null(feature_weights)) {
138+
setinfo(dmat, "feature_weights", feature_weights)
139+
}
140+
96141
return(dmat)
97142
}
98143

@@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) {
211256
#' The \code{name} field can be one of the following:
212257
#'
213258
#' \itemize{
214-
#' \item \code{label}: label XGBoost learn from ;
215-
#' \item \code{weight}: to do a weight rescale ;
216-
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
217-
#' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
218-
#'
259+
#' \item \code{label}
260+
#' \item \code{weight}
261+
#' \item \code{base_margin}
262+
#' \item \code{label_lower_bound}
263+
#' \item \code{label_upper_bound}
264+
#' \item \code{group}
265+
#' \item \code{feature_type}
266+
#' \item \code{feature_name}
267+
#' \item \code{nrow}
219268
#' }
269+
#' See the documentation for \link{xgb.DMatrix} for more information about these fields.
220270
#'
221-
#' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
271+
#' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
272+
#' for a DMatrix that had 'qid' assigned.
222273
#'
223274
#' @examples
224275
#' data(agaricus.train, package='xgboost')
@@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo")
236287
#' @rdname getinfo
237288
#' @export
238289
getinfo.xgb.DMatrix <- function(object, name, ...) {
290+
allowed_int_fields <- 'group'
291+
allowed_float_fields <- c(
292+
'label', 'weight', 'base_margin',
293+
'label_lower_bound', 'label_upper_bound'
294+
)
295+
allowed_str_fields <- c("feature_type", "feature_name")
296+
allowed_fields <- c(allowed_float_fields, allowed_int_fields, allowed_str_fields, 'nrow')
297+
239298
if (typeof(name) != "character" ||
240299
length(name) != 1 ||
241-
!name %in% c('label', 'weight', 'base_margin', 'nrow',
242-
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
243-
stop(
244-
"getinfo: name must be one of the following\n",
245-
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
246-
)
300+
!name %in% allowed_fields) {
301+
stop("getinfo: name must be one of the following\n",
302+
paste(paste0("'", allowed_fields, "'"), collapse = ", "))
247303
}
248-
if (name == "feature_name" || name == "feature_type") {
304+
if (name == "nrow") {
305+
ret <- nrow(object)
306+
} else if (name %in% allowed_str_fields) {
249307
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
250-
} else if (name != "nrow") {
251-
ret <- .Call(XGDMatrixGetInfo_R, object, name)
308+
} else if (name %in% allowed_float_fields) {
309+
ret <- .Call(XGDMatrixGetFloatInfo_R, object, name)
310+
if (length(ret) > nrow(object)) {
311+
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
312+
}
313+
} else if (name %in% allowed_int_fields) {
314+
if (name == "group") {
315+
name <- "group_ptr"
316+
}
317+
ret <- .Call(XGDMatrixGetUIntInfo_R, object, name)
252318
if (length(ret) > nrow(object)) {
253319
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
254320
}
255-
} else {
256-
ret <- nrow(object)
257321
}
258322
if (length(ret) == 0) return(NULL)
259323
return(ret)
@@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
270334
#' @param ... other parameters
271335
#'
272336
#' @details
273-
#' The \code{name} field can be one of the following:
274-
#'
275-
#' \itemize{
276-
#' \item \code{label}: label XGBoost learn from ;
277-
#' \item \code{weight}: to do a weight rescale ;
278-
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
279-
#' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
337+
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
338+
#' (which correspond to arguments in that function).
339+
#'
340+
#' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
341+
#' but \bold{aren't} allowed here:\itemize{
342+
#' \item data
343+
#' \item missing
344+
#' \item silent
345+
#' \item nthread
280346
#' }
281347
#'
282348
#' @examples
@@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
328394
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
329395
return(TRUE)
330396
}
397+
if (name == "qid") {
398+
if (NROW(info) != nrow(object))
399+
stop("The length of qid assignments must equal to the number of rows in the input data")
400+
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
401+
return(TRUE)
402+
}
331403
if (name == "feature_weights") {
332404
if (length(info) != ncol(object)) {
333405
stop("The number of feature weights must equal to the number of columns in the input data")

R-package/man/getinfo.Rd

Lines changed: 12 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/man/setinfo.Rd

Lines changed: 8 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/man/xgb.DMatrix.Rd

Lines changed: 30 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/src/init.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
3939
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
4040
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
4141
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
42-
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
42+
extern SEXP XGDMatrixGetFloatInfo_R(SEXP, SEXP);
43+
extern SEXP XGDMatrixGetUIntInfo_R(SEXP, SEXP);
4344
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
4445
extern SEXP XGDMatrixNumCol_R(SEXP);
4546
extern SEXP XGDMatrixNumRow_R(SEXP);
@@ -76,7 +77,8 @@ static const R_CallMethodDef CallEntries[] = {
7677
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
7778
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
7879
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
79-
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
80+
{"XGDMatrixGetFloatInfo_R", (DL_FUNC) &XGDMatrixGetFloatInfo_R, 2},
81+
{"XGDMatrixGetUIntInfo_R", (DL_FUNC) &XGDMatrixGetUIntInfo_R, 2},
8082
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
8183
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
8284
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},

0 commit comments

Comments
 (0)