8
8
# ' a \code{dgRMatrix} object,
9
9
# ' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
10
10
# ' interpreted as a row vector), or a character string representing a filename.
11
- # ' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
12
- # ' See \code{\link{setinfo}} for the specific allowed kinds of
11
+ # ' @param label Label of the training data.
12
+ # ' @param weight Weight for each instance.
13
+ # '
14
+ # ' Note that, for ranking task, weights are per-group. In ranking task, one weight
15
+ # ' is assigned to each group (not each data point). This is because we
16
+ # ' only care about the relative ordering of data points within each group,
17
+ # ' so it doesn't make sense to assign weights to individual data points.
18
+ # ' @param base_margin Base margin used for boosting from existing model.
13
19
# ' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
14
20
# ' It is useful when a 0 or some other extreme value represents missing values in data.
15
21
# ' @param silent whether to suppress printing an informational message after loading from a file.
22
+ # ' @param feature_names Set names for features.
16
23
# ' @param nthread Number of threads used for creating DMatrix.
17
- # ' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
24
+ # ' @param group Group size for all ranking group.
25
+ # ' @param qid Query ID for data samples, used for ranking.
26
+ # ' @param label_lower_bound Lower bound for survival training.
27
+ # ' @param label_upper_bound Upper bound for survival training.
28
+ # ' @param feature_weights Set feature weights for column sampling.
18
29
# '
19
30
# ' @details
20
31
# ' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
34
45
# ' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
35
46
# ' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
36
47
# ' @export
37
- xgb.DMatrix <- function (data , info = list (), missing = NA , silent = FALSE , nthread = NULL , ... ) {
38
- cnames <- NULL
48
+ xgb.DMatrix <- function (
49
+ data ,
50
+ label = NULL ,
51
+ weight = NULL ,
52
+ base_margin = NULL ,
53
+ missing = NA ,
54
+ silent = FALSE ,
55
+ feature_names = colnames(data ),
56
+ nthread = NULL ,
57
+ group = NULL ,
58
+ qid = NULL ,
59
+ label_lower_bound = NULL ,
60
+ label_upper_bound = NULL ,
61
+ feature_weights = NULL
62
+ ) {
63
+ if (! is.null(group ) && ! is.null(qid )) {
64
+ stop(" Either one of 'group' or 'qid' should be NULL" )
65
+ }
39
66
if (typeof(data ) == " character" ) {
40
67
if (length(data ) > 1 )
41
68
stop(" 'data' has class 'character' and length " , length(data ),
@@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
44
71
handle <- .Call(XGDMatrixCreateFromFile_R , data , as.integer(silent ))
45
72
} else if (is.matrix(data )) {
46
73
handle <- .Call(XGDMatrixCreateFromMat_R , data , missing , as.integer(NVL(nthread , - 1 )))
47
- cnames <- colnames(data )
48
74
} else if (inherits(data , " dgCMatrix" )) {
49
75
handle <- .Call(
50
76
XGDMatrixCreateFromCSC_R ,
@@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
55
81
missing ,
56
82
as.integer(NVL(nthread , - 1 ))
57
83
)
58
- cnames <- colnames(data )
59
84
} else if (inherits(data , " dgRMatrix" )) {
60
85
handle <- .Call(
61
86
XGDMatrixCreateFromCSR_R ,
@@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
66
91
missing ,
67
92
as.integer(NVL(nthread , - 1 ))
68
93
)
69
- cnames <- colnames(data )
70
94
} else if (inherits(data , " dsparseVector" )) {
71
95
indptr <- c(0L , as.integer(length(data @ i )))
72
96
ind <- as.integer(data @ i ) - 1L
@@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
82
106
} else {
83
107
stop(" xgb.DMatrix does not support construction from " , typeof(data ))
84
108
}
109
+
85
110
dmat <- handle
86
111
attributes(dmat ) <- list (class = " xgb.DMatrix" )
87
- if (! is.null(cnames )) {
88
- setinfo(dmat , " feature_name" , cnames )
89
- }
90
112
91
- info <- append(info , list (... ))
92
- for (i in seq_along(info )) {
93
- p <- info [i ]
94
- setinfo(dmat , names(p ), p [[1 ]])
113
+ if (! is.null(label )) {
114
+ setinfo(dmat , " label" , label )
115
+ }
116
+ if (! is.null(weight )) {
117
+ setinfo(dmat , " weight" , weight )
118
+ }
119
+ if (! is.null(base_margin )) {
120
+ setinfo(dmat , " base_margin" , base_margin )
121
+ }
122
+ if (! is.null(feature_names )) {
123
+ setinfo(dmat , " feature_name" , feature_names )
124
+ }
125
+ if (! is.null(group )) {
126
+ setinfo(dmat , " group" , group )
127
+ }
128
+ if (! is.null(qid )) {
129
+ setinfo(dmat , " qid" , qid )
95
130
}
131
+ if (! is.null(label_lower_bound )) {
132
+ setinfo(dmat , " label_lower_bound" , label_lower_bound )
133
+ }
134
+ if (! is.null(label_upper_bound )) {
135
+ setinfo(dmat , " label_upper_bound" , label_upper_bound )
136
+ }
137
+ if (! is.null(feature_weights )) {
138
+ setinfo(dmat , " feature_weights" , feature_weights )
139
+ }
140
+
96
141
return (dmat )
97
142
}
98
143
@@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) {
211
256
# ' The \code{name} field can be one of the following:
212
257
# '
213
258
# ' \itemize{
214
- # ' \item \code{label}: label XGBoost learn from ;
215
- # ' \item \code{weight}: to do a weight rescale ;
216
- # ' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
217
- # ' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
218
- # '
259
+ # ' \item \code{label}
260
+ # ' \item \code{weight}
261
+ # ' \item \code{base_margin}
262
+ # ' \item \code{label_lower_bound}
263
+ # ' \item \code{label_upper_bound}
264
+ # ' \item \code{group}
265
+ # ' \item \code{feature_type}
266
+ # ' \item \code{feature_name}
267
+ # ' \item \code{nrow}
219
268
# ' }
269
+ # ' See the documentation for \link{xgb.DMatrix} for more information about these fields.
220
270
# '
221
- # ' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
271
+ # ' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
272
+ # ' for a DMatrix that had 'qid' assigned.
222
273
# '
223
274
# ' @examples
224
275
# ' data(agaricus.train, package='xgboost')
@@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo")
236
287
# ' @rdname getinfo
237
288
# ' @export
238
289
getinfo.xgb.DMatrix <- function (object , name , ... ) {
290
+ allowed_int_fields <- ' group'
291
+ allowed_float_fields <- c(
292
+ ' label' , ' weight' , ' base_margin' ,
293
+ ' label_lower_bound' , ' label_upper_bound'
294
+ )
295
+ allowed_str_fields <- c(" feature_type" , " feature_name" )
296
+ allowed_fields <- c(allowed_float_fields , allowed_int_fields , allowed_str_fields , ' nrow' )
297
+
239
298
if (typeof(name ) != " character" ||
240
299
length(name ) != 1 ||
241
- ! name %in% c(' label' , ' weight' , ' base_margin' , ' nrow' ,
242
- ' label_lower_bound' , ' label_upper_bound' , " feature_type" , " feature_name" )) {
243
- stop(
244
- " getinfo: name must be one of the following\n " ,
245
- " 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
246
- )
300
+ ! name %in% allowed_fields ) {
301
+ stop(" getinfo: name must be one of the following\n " ,
302
+ paste(paste0(" '" , allowed_fields , " '" ), collapse = " , " ))
247
303
}
248
- if (name == " feature_name" || name == " feature_type" ) {
304
+ if (name == " nrow" ) {
305
+ ret <- nrow(object )
306
+ } else if (name %in% allowed_str_fields ) {
249
307
ret <- .Call(XGDMatrixGetStrFeatureInfo_R , object , name )
250
- } else if (name != " nrow" ) {
251
- ret <- .Call(XGDMatrixGetInfo_R , object , name )
308
+ } else if (name %in% allowed_float_fields ) {
309
+ ret <- .Call(XGDMatrixGetFloatInfo_R , object , name )
310
+ if (length(ret ) > nrow(object )) {
311
+ ret <- matrix (ret , nrow = nrow(object ), byrow = TRUE )
312
+ }
313
+ } else if (name %in% allowed_int_fields ) {
314
+ if (name == " group" ) {
315
+ name <- " group_ptr"
316
+ }
317
+ ret <- .Call(XGDMatrixGetUIntInfo_R , object , name )
252
318
if (length(ret ) > nrow(object )) {
253
319
ret <- matrix (ret , nrow = nrow(object ), byrow = TRUE )
254
320
}
255
- } else {
256
- ret <- nrow(object )
257
321
}
258
322
if (length(ret ) == 0 ) return (NULL )
259
323
return (ret )
@@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
270
334
# ' @param ... other parameters
271
335
# '
272
336
# ' @details
273
- # ' The \code{name} field can be one of the following:
274
- # '
275
- # ' \itemize{
276
- # ' \item \code{label}: label XGBoost learn from ;
277
- # ' \item \code{weight}: to do a weight rescale ;
278
- # ' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
279
- # ' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
337
+ # ' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
338
+ # ' (which correspond to arguments in that function).
339
+ # '
340
+ # ' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
341
+ # ' but \bold{aren't} allowed here:\itemize{
342
+ # ' \item data
343
+ # ' \item missing
344
+ # ' \item silent
345
+ # ' \item nthread
280
346
# ' }
281
347
# '
282
348
# ' @examples
@@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
328
394
.Call(XGDMatrixSetInfo_R , object , name , as.integer(info ))
329
395
return (TRUE )
330
396
}
397
+ if (name == " qid" ) {
398
+ if (NROW(info ) != nrow(object ))
399
+ stop(" The length of qid assignments must equal to the number of rows in the input data" )
400
+ .Call(XGDMatrixSetInfo_R , object , name , as.integer(info ))
401
+ return (TRUE )
402
+ }
331
403
if (name == " feature_weights" ) {
332
404
if (length(info ) != ncol(object )) {
333
405
stop(" The number of feature weights must equal to the number of columns in the input data" )
0 commit comments