Classification (family = 'cls')
In this module, the model for each source domain can be any machine learning model. The CGDRO learner uses a cross-entropy loss to aggregate the source domains into a model for the target domain. For more details on the method, please refer to CGDRO-Classification.
Use cgdro_() with family = 'cls' for classification.
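For reference, the base-R sketch below computes the multi-class cross-entropy $-\frac{1}{n}\sum_i \log \hat p_{i,y_i}$ for a matrix of predicted probabilities. It is a minimal illustration, not the package's internal code: `cross_entropy()` is a hypothetical helper, and it assumes labels are coded $0, \dots, C-1$ with column $j$ of the probability matrix holding label $j-1$. That coding matches the labels printed in the Prediction section.

```r
# Hypothetical helper (not part of the package): multi-class cross-entropy.
# proba: n x C matrix of predicted class probabilities (rows sum to 1)
# y:     labels coded 0, ..., C - 1 (column j of proba <-> label j - 1)
cross_entropy <- function(proba, y) {
  stopifnot(nrow(proba) == length(y))
  # for each row i, pick the probability assigned to the observed label y[i]
  p_obs <- proba[cbind(seq_along(y), y + 1)]
  -mean(log(p_obs))
}

# toy example with C = 3 classes
proba <- rbind(c(0.7, 0.2, 0.1),
               c(0.1, 0.8, 0.1),
               c(0.2, 0.3, 0.5))
y <- c(0, 1, 2)
cross_entropy(proba, y)
```

A uniform prediction over $C$ classes gives a loss of $\log C$, which is a handy sanity check.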
Example
Data Generating Process
In this example, we generate multi-source data with $L=2$ source domains, with $n=100$ samples in each source domain and $N=1,000$ samples in the target domain. We consider a multi-class classification problem with $C=K+1=3$ classes. Each class-specific coefficient vector has dimension $p=5$.
```r
# two source groups, each with 100 samples, and 1000 target samples
n <- 100; p <- 5; L <- 2; N <- 1000; K <- 2
data <- simu_cls(n, N, p, L, K, seed = 123)
Xlist <- data$X_list
Ylist <- data$Y_list
X0 <- data$X0
```
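simu_cls() returns a list whose X_list and Y_list presumably hold one element per source group, and whose X0 holds the target covariates. Its exact data generating process is not described here, so the base-R sketch below only reproduces those shapes. The generator `simulate_multisource_cls()` is hypothetical, and its Gaussian covariates, random group-specific coefficients, and multinomial-logit labels with class 0 as the reference class are all assumptions.

```r
# Hypothetical generator (NOT the package's simu_cls): same output shapes,
# illustrative multinomial-logit data generating process.
simulate_multisource_cls <- function(n, N, p, L, K, seed = 123) {
  set.seed(seed)
  # softmax over K + 1 classes; class 0 (first column) is the reference
  draw_labels <- function(X, B) {
    eta <- cbind(0, X %*% B)                # n x (K + 1) linear scores
    proba <- exp(eta - apply(eta, 1, max))  # row-wise stabilised exponentials
    proba <- proba / rowSums(proba)         # normalise each row to sum to one
    apply(proba, 1, function(pr) sample(0:K, 1, prob = pr))
  }
  X_list <- vector("list", L)
  Y_list <- vector("list", L)
  for (l in seq_len(L)) {
    B <- matrix(rnorm(p * K), p, K)         # group-specific p x K coefficients
    X_list[[l]] <- matrix(rnorm(n * p), n, p)
    Y_list[[l]] <- draw_labels(X_list[[l]], B)
  }
  X0 <- matrix(rnorm(N * p), N, p)          # target covariates only, no labels
  list(X_list = X_list, Y_list = Y_list, X0 = X0)
}

sim <- simulate_multisource_cls(n = 100, N = 1000, p = 5, L = 2, K = 2)
str(sim, max.level = 1)
```

Varying the coefficients across groups is what makes the source domains heterogeneous, which is the setting where aggregating them for a target domain is non-trivial.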
Implementation & Results
```r
## fit CGDRO
## using linear regression as f_learner and logistic regression as w_learner
fit <- cgdro_(Xlist, Ylist, X0,
              family = "cls", f_learner = "linear", w_learner = "logistic")
## inference: alpha = 0.05 gives 95% confidence intervals;
## parallel = TRUE runs the computation on n_workers = 4 workers
inf <- infer_cgdro_(fit, M = 200, alpha = 0.05, parallel = TRUE, n_workers = 4, diag = TRUE)
summary_cgdro_(fit, infer = inf)
```
```text
Model Summary
=================================
CGDRO Aggregated Weights:
group   | 1      2
weight_ | 0.3288 0.6712
=================================
CGDRO Aggregated Estimators:
Class 1 coefficients:
index | 1      2      3      4       5
coef_ | 0.0788 0.2761 0.5442 -0.6597 0.2892
Class 2 coefficients:
index | 1       2      3      4       5
coef_ | -0.0427 0.2596 0.5894 -0.3600 0.0625
=================================
Confidence Intervals:
Class 1 Confidence Intervals:
index | 1              2              3              4              5
CIs   | (-1.449,1.036) (-1.425,1.752) (-1.004,2.075) (-2.118,0.543) (-0.978,1.061)
Class 2 Confidence Intervals:
index | 1              2              3              4              5
CIs   | (-1.573,0.814) (-1.237,1.486) (-0.563,1.899) (-2.315,0.835) (-1.349,1.039)
```
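With alpha = 0.05, the reported intervals are read here as $1-\alpha = 95\%$ intervals, assuming the usual convention. The base-R check below works only on the numbers printed in the summary: it confirms that the two aggregated weights sum to one and lists the coefficients whose interval excludes zero. `excludes_zero()` is a hypothetical helper, not a package function.

```r
# Values copied from the summary output above
weights <- c(0.3288, 0.6712)
ci_class1 <- rbind(lower = c(-1.449, -1.425, -1.004, -2.118, -0.978),
                   upper = c( 1.036,  1.752,  2.075,  0.543,  1.061))
ci_class2 <- rbind(lower = c(-1.573, -1.237, -0.563, -2.315, -1.349),
                   upper = c( 0.814,  1.486,  1.899,  0.835,  1.039))

# TRUE where the whole interval lies strictly above or strictly below zero
excludes_zero <- function(ci) ci["lower", ] > 0 | ci["upper", ] < 0

sum(weights)                     # aggregated weights over the two groups
which(excludes_zero(ci_class1))  # coefficient indices with intervals excluding 0
which(excludes_zero(ci_class2))
```

In this run every interval contains zero, so none of the aggregated coefficients is significantly different from zero at the 5% level with these sample sizes.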
```r
## get inference results for coefficient indices 1 and 3 of class 2
summary_cgdro_(fit, infer = inf, index = c(1, 3), class_index = c(2))
```
```text
Model Summary
=================================
CGDRO Aggregated Weights:
group   | 1      2
weight_ | 0.3288 0.6712
=================================
CGDRO Aggregated Estimators:
Class 2 coefficients:
index | 1       3
coef_ | -0.0427 0.5894
=================================
Confidence Intervals:
Class 2 Confidence Intervals:
index | 1              3
CIs   | (-1.573,0.814) (-0.563,1.899)
```
The statistical inference results from CGDRO contain three parts: the CGDRO Aggregated Weights (the weights learned for each source group), the CGDRO Aggregated Estimators (the worst-case coefficient estimators on the target domain), and the Confidence Intervals (valid confidence intervals for the target-domain coefficients). In the summaries above, group indexes the source groups; index indexes the coefficients, starting from the intercept if intercept=TRUE and from the first covariate otherwise; and Class runs from $1$ to $K=C-1$.
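The Class $1, \dots, K$ indexing suggests that class 0 acts as a reference class whose coefficients are fixed at zero, as in a standard multinomial logit. Under that assumption, and assuming no intercept (five coefficients are reported for $p=5$), the sketch below shows how the $K$ printed coefficient vectors map a covariate vector to $C=K+1$ softmax probabilities. `softmax_ref()` and the covariate vector `x` are hypothetical, and this need not reproduce how predict_cgdro_() computes its probabilities.

```r
# Assumptions: class 0 is the reference class (coefficients fixed at 0) and the
# model has no intercept. Coefficients copied from the summary output above.
beta <- cbind(class1 = c( 0.0788, 0.2761, 0.5442, -0.6597, 0.2892),
              class2 = c(-0.0427, 0.2596, 0.5894, -0.3600, 0.0625))

# Hypothetical helper: softmax over the reference class plus K coded classes
softmax_ref <- function(x, beta) {
  eta <- c(class0 = 0, drop(crossprod(beta, x)))  # linear scores, class 0 first
  z <- exp(eta - max(eta))                        # subtract max for stability
  z / sum(z)
}

x <- c(1, 0, 0, 0, 0)  # hypothetical covariate vector
round(softmax_ref(x, beta), 4)
```

Fixing one class at zero removes the non-identifiability of the softmax, since adding the same vector to every class's coefficients leaves the probabilities unchanged. This is why only $K=C-1$ coefficient vectors are reported.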
Prediction
Make predictions on the target data and show the first 6 predicted softmax probabilities and labels. The covariates do not need to be passed explicitly, because the target data X0 is the default choice for prediction.
```r
pred <- predict_cgdro_(fit)  # predicts on the target covariates X0 by default
head(pred$pred_proba)        # N x C matrix of predicted probabilities
```
| Class 0 | Class 1 | Class 2 |
|---|---|---|
| 0.3410477 | 0.3179046 | 0.3410477 |
| 0.3565226 | 0.3565226 | 0.2869548 |
| 0.3658669 | 0.3658669 | 0.2682663 |
| 0.3429900 | 0.3429900 | 0.3140200 |
| 0.3698810 | 0.2602380 | 0.3698810 |
| 0.3357761 | 0.3357761 | 0.3284478 |
```r
head(pred$pred)  # predicted class labels
```
- 0
- 1
- 0
- 1
- 2
- 1
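Every displayed row has two classes tied for the largest probability, so the reported label depends on how ties are broken. The base-R check below runs only on the printed values. It confirms that each row sums to one (up to the 7 printed decimals) and that each printed label is one of the tied arg-max classes. The check assumes column $j$ of pred_proba corresponds to label $j-1$. How predict_cgdro_() actually breaks ties is not documented here; random tie-breaking (the default of base R's `max.col()`) is one plausible mechanism, not a confirmed one.

```r
# Probabilities and labels copied from the output above
pred_proba <- rbind(c(0.3410477, 0.3179046, 0.3410477),
                    c(0.3565226, 0.3565226, 0.2869548),
                    c(0.3658669, 0.3658669, 0.2682663),
                    c(0.3429900, 0.3429900, 0.3140200),
                    c(0.3698810, 0.2602380, 0.3698810),
                    c(0.3357761, 0.3357761, 0.3284478))
pred_label <- c(0, 1, 0, 1, 2, 1)

# each row should be a probability vector, up to rounding of the printout
all(abs(rowSums(pred_proba) - 1) < 1e-6)

# for each row, the labels (column j <-> label j - 1) attaining the maximum
argmax_sets <- lapply(seq_len(nrow(pred_proba)), function(i) {
  pr <- pred_proba[i, ]
  which(pr == max(pr)) - 1
})
lengths(argmax_sets)  # number of classes tied at the maximum in each row

# is each printed label one of the tied arg-max labels?
consistent <- mapply(function(lbl, set) lbl %in% set, pred_label, argmax_sets)
consistent
```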