suppressPackageStartupMessages({
library(this.path)
library(Rcpp)
library(dplyr)
library(data.table)
library(qs2)
library(stringplus)
library(parallel)
library(xgboost)
  library(caret)
  library(ggplot2)
  library(future.apply)
})
plan("future::multisession", workers=16)

# dynamic blocksize set to 1048576
# qs2:::internal_set_blocksize(1048576)

DATA_PATH <- "~/datasets/processed"

# training datasets
datasets <- DATA_PATH & "/" & c("DC_real_estate_2024.json.gz", "dslabs_mnist.rds", 
                                "enwik8.csv.gz", "era5_land_wind_20230101.rds",
                                "GAIA_pseudocolor.csv.gz", "NYSE_1962_2024.csv.gz",
                                "recount3_gtex_heart.rds", "T_cell_ADIRP0000010.rds",
                                "pisces_2018.csv.gz", "Berkeley_grid_temp_2010.rds",
                                "Oahu_OSM.rds", "Clifford_100M.rds",
                                "methylation_450k.rds", "NYC_motor_vehicle_collisions.csv.gz",
                                "steam_games_2024.csv.gz", "twitter_sentiment140.csv.gz")

# holdout testing datasets
# datasets <- DATA_PATH & "/" & c("1000genomes_noncoding_vcf.csv.gz", "B_cell_petshopmouse3.tsv.gz",
#                                 "ip_location_2023.csv.gz", "Netflix_Ratings.csv.gz")

read_dataset <- function(d) {
  if(d %like% "\\.json\\.gz$") {
    RcppSimdJson::fload(d)
  } else if(d %like% "\\.csv\\.gz$") {
    fread(d, sep = ",", data.table = FALSE)
  } else if(d %like% "\\.tsv\\.gz$") {
    fread(d, sep = "\t", data.table = FALSE)
  } else {
    readRDS(d)
  }
}
outfile <- "data/blockshuffle_training_data.csv.gz"
if(!file.exists(outfile)) {
  
  # compile the shuffle heuristic / compression helpers against libzstd, with AVX2 enabled
  cflags <- system("pkg-config --cflags libzstd", intern = TRUE)
  libs <- system("pkg-config --libs libzstd", intern = TRUE)
  Sys.setenv(PKG_CPPFLAGS = "-mavx2 %s" | cflags)
  Sys.setenv(PKG_LIBS = libs)
  sourceCpp("blockshuffle_heuristic.cpp", verbose=TRUE, rebuild = TRUE)
  min_shuffleblock_size <- 262144
  
  # serialize each dataset with qs2 and extract its raw (uncompressed) blocks
  blocks_df <- lapply(datasets, function(d) {
    tmp <- tempfile()
    data <- read_dataset(d)
    dname <- basename(d) %>% gsub("\\..+", "", .)
    qs2::qd_save(data, file = tmp)
    x <- qs2::qx_dump(tmp)
    tibble(dataset = dname, blocks = x$blocks)
  }) %>% rbindlist
  blocks_df$blocksize <- sapply(blocks_df$blocks, length)
  blocks_df <- filter(blocks_df, blocksize >= min_shuffleblock_size) # drop blocks too small to be worth shuffling
  
  gc(full=TRUE)
  compress_levels <- 22:1
  results <- mclapply(compress_levels, function(cl) {
    print(cl)
    # heuristic features per block, plus compressed sizes with and without byte-shuffling
    output <- shuffle_heuristic(blocks_df$blocks)
    output$no_shuffle_zblocksize <- og_compress(blocks_df$blocks, cl)$size
    output$shuffle_zblocksize <- shuffle_compress(blocks_df$blocks, 8, cl)$size
    cbind(blocks_df %>% dplyr::select(blocksize, dataset), output) %>%
      mutate(compress_level = cl)
  }, mc.cores = 8, mc.preschedule = FALSE) %>% rbindlist
  
  # add block index per dataset
  results <- results %>%
    group_by(dataset, compress_level) %>%
    mutate(index = row_number()) %>%
    ungroup
  fwrite(results, outfile, sep = ",")
} else {
  results <- fread(outfile, data.table=FALSE)
}
datasets <- datasets %>% basename %>% gsub("\\..+$", "", .)

# prediction accuracy falls off above compress level 16,
# so if the user requests CL 16 or higher we won't rely on this heuristic and will instead try compression both ways
data <- results %>% 
  filter(compress_level <= 22) %>%
  mutate(improvement = log(no_shuffle_zblocksize / shuffle_zblocksize)) %>% # > 0 means shuffling compressed better
  group_by(dataset) %>%
  mutate(weight = 1/n()) %>% # weight each dataset equally regardless of how many blocks it contributes
  ungroup %>%
  mutate(weight = weight / mean(weight)) %>% # normalize so that mean(weight) == 1
  as.data.frame
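
# Illustrative sketch of the fallback mentioned above (not used anywhere in this script):
# for compress levels above 16, skip the model and simply compress a block both ways,
# keeping whichever result is smaller. This assumes og_compress() and shuffle_compress()
# (compiled from blockshuffle_heuristic.cpp) accept a list of raw blocks and return
# per-block compressed sizes in $size, as they are used above; elem_width is an assumed
# name for the byte-shuffle stride (8 in the training runs above).
choose_shuffle_bruteforce <- function(block, cl, elem_width = 8) {
  plain    <- og_compress(list(block), cl)$size
  shuffled <- shuffle_compress(list(block), elem_width, cl)$size
  shuffled < plain  # TRUE -> store the byte-shuffled block
}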

set.seed(314156)
train_index <- list()
valid_index <- list()
# 80/20 train/validation split of block indices, per dataset
for(d in datasets) {
  train_index[[d]] <- filter(data, dataset == d) %>%
    pull(index) %>%
    unique %>% sample(0.8 * length(.)) %>% sort
  valid_index[[d]] <- filter(data, dataset == d) %>%
    pull(index) %>%
    unique %>% setdiff(train_index[[d]]) %>% sort
}

train_data <- lapply(datasets, function(d) {
  filter(data, dataset == d) %>%
    filter(index %in% train_index[[d]])
}) %>% rbindlist

valid_data <- lapply(datasets, function(d) {
  filter(data, dataset == d) %>%
    filter(index %in% valid_index[[d]])
}) %>% rbindlist

# 5-fold CV fold assignments, stratified so each dataset's blocks are spread evenly across folds
splits <- vector("list", 5)
for(d in datasets) {
  td <- train_data %>% filter(dataset == d)
  x <- td %>% pull(index) %>% 
    unique %>% sample
  x <- split(x, rep_len(1:5, length(x)))
  for(i in 1:5) {
    # get row index in train_data and append to splits
    splits[[i]] <- c(splits[[i]], which(train_data$index %in% x[[i]] & train_data$dataset == d))
  }
}
outfile <- "data/blockshuffle_param_search_cv.csv.gz"

if(!file.exists(outfile)) {
  # 3^5 = 243 hyperparameter combinations
  param_grid <- expand.grid(
    max_depth = c(4, 6, 8),
    eta = c(0.01, 0.1, 0.3),
    colsample_bytree = c(0.5, 0.7, 1.0),
    min_child_weight = c(1, 5, 10),
    subsample = c(0.5, 0.7, 1.0)
  )
  
  scores <- future_lapply(1:nrow(param_grid), function(q) {
    dtrain <- xgb.DMatrix(data = train_data %>% 
                            dplyr::select(-dataset, -index, -improvement, -no_shuffle_zblocksize, -shuffle_zblocksize, -time, -blocksize, -weight) %>% 
                            data.matrix, label = train_data$improvement, weight = train_data$weight)
    
    params <- list(
      objective = "reg:squarederror",  # Regression task
      max_depth = param_grid$max_depth[q],  # Maximum depth of a tree
      eta = param_grid$eta[q],         # Learning rate
      colsample_bytree = param_grid$colsample_bytree[q], # Subsample ratio of columns
      min_child_weight = param_grid$min_child_weight[q], # Minimum sum of instance weight(hessian) needed in a child
      subsample = param_grid$subsample[q], # Subsample ratio of the training instance
      nthread = 1)                      # Number of parallel threads
    
    bcv <- xgb.cv(
      params = params,
      data = dtrain,
      nrounds = 1000,
      showsd = FALSE,
      metrics = list("rmse"),
      folds = splits,
      verbose = TRUE,
      print_every_n = 10,
      early_stopping_rounds = 10)
    
    cbind(param_grid[q,], bcv$evaluation_log, row.names=NULL)
  }, future.seed=TRUE, future.globals = c("param_grid", "train_data", "splits"), future.packages = c("xgboost", "dplyr"))
  
  scores <- rbindlist(scores)
  fwrite(scores, file = outfile, sep = ",")
} else {
  scores <- fread(outfile, data.table=FALSE)
}
# evaluate parameter range
best_scores <- scores %>% 
  group_by(max_depth, eta, colsample_bytree, min_child_weight, subsample) %>% 
  filter(test_rmse_mean == min(test_rmse_mean)) %>% ungroup

ggplot(best_scores, aes(x=iter, y = test_rmse_mean)) +
  geom_point(color = "red") + 
  geom_errorbar(aes(ymin = test_rmse_mean - test_rmse_std, ymax = test_rmse_mean + test_rmse_std)) +
  theme_bw(base_size = 12) + 
  labs(subtitle = "best score by # of iterations")

best_params <- filter(best_scores, iter < 750) %>% # limit size of model
  filter(test_rmse_mean == min(test_rmse_mean))
print(best_params)
## # A tibble: 1 × 10
##   max_depth   eta colsample_bytree min_child_weight subsample  iter
##       <int> <dbl>            <dbl>            <int>     <dbl> <int>
## 1         8   0.1              0.7               10       0.5   372
## # ℹ 4 more variables: train_rmse_mean <dbl>, train_rmse_std <dbl>,
## #   test_rmse_mean <dbl>, test_rmse_std <dbl>
dtrain <- xgb.DMatrix(data = train_data %>% 
                        dplyr::select(-dataset, -index, -improvement, -no_shuffle_zblocksize, -shuffle_zblocksize, -time, -blocksize, -weight) %>% 
                        data.matrix, label = train_data$improvement, weight = train_data$weight)

dvalid <- xgb.DMatrix(data = valid_data %>% 
                        dplyr::select(-dataset, -index, -improvement, -no_shuffle_zblocksize, -shuffle_zblocksize, -time, -blocksize, -weight) %>% 
                        data.matrix, label = valid_data$improvement, weight = valid_data$weight)

params <- list(
  objective = "reg:squarederror",                  # Regression task
  max_depth = best_params$max_depth,               # Maximum depth of a tree
  eta = best_params$eta,                           # Learning rate
  colsample_bytree = best_params$colsample_bytree, # Subsample ratio of columns
  min_child_weight = best_params$min_child_weight, # Minimum sum of instance weight (hessian) needed in a child
  subsample = best_params$subsample,               # Subsample ratio of the training instances
  nthread = 1)                                     # Number of parallel threads

bst <- xgb.train(
  params = params,
  data = dtrain,
  nrounds = 1000,                                    # Number of boosting rounds
  watchlist = list(train = dtrain, test = dvalid),   # For tracking performance
  print_every_n = 10,                                # Print progress every 10 rounds
  early_stopping_rounds = 10)                        # Stop early if no improvement for 10 rounds
## [1]  train-rmse:0.953713 test-rmse:0.923433 
## Multiple eval metrics are present. Will use test_rmse for early stopping.
## Will train until test_rmse hasn't improved in 10 rounds.
## 
## [11] train-rmse:0.368739 test-rmse:0.376905 
## [21] train-rmse:0.178799 test-rmse:0.213039 
## [31] train-rmse:0.119414 test-rmse:0.169921 
## [41] train-rmse:0.098785 test-rmse:0.156789 
## [51] train-rmse:0.088646 test-rmse:0.152115 
## [61] train-rmse:0.082118 test-rmse:0.149188 
## [71] train-rmse:0.076336 test-rmse:0.146810 
## [81] train-rmse:0.072756 test-rmse:0.145407 
## [91] train-rmse:0.068961 test-rmse:0.144291 
## [101]    train-rmse:0.066483 test-rmse:0.143016 
## [111]    train-rmse:0.063931 test-rmse:0.142367 
## [121]    train-rmse:0.062276 test-rmse:0.141932 
## [131]    train-rmse:0.059729 test-rmse:0.141412 
## [141]    train-rmse:0.057968 test-rmse:0.141196 
## [151]    train-rmse:0.056488 test-rmse:0.140988 
## [161]    train-rmse:0.055480 test-rmse:0.140662 
## [171]    train-rmse:0.054347 test-rmse:0.140185 
## [181]    train-rmse:0.053191 test-rmse:0.139993 
## [191]    train-rmse:0.051716 test-rmse:0.139800 
## [201]    train-rmse:0.050563 test-rmse:0.139899 
## Stopping. Best iteration:
## [193]    train-rmse:0.051423 test-rmse:0.139746
# save model
xgboost::xgb.save(bst, fname = "data/blockshuffle_xgboost_model.json")
## [1] TRUE
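
# Illustrative sketch (commented out, not run here) of applying the saved model to a new
# block at serialization time. It assumes shuffle_heuristic() returns the same feature
# columns used for training (everything except `time`), in the same order, with
# compress_level appended last; `new_block` is a hypothetical raw vector.
# bst2 <- xgboost::xgb.load("data/blockshuffle_xgboost_model.json")
# new_feats <- shuffle_heuristic(list(new_block)) %>%
#   dplyr::select(-time) %>%
#   mutate(compress_level = 5) %>%
#   data.matrix
# predict(bst2, new_feats) > 0  # TRUE -> byte-shuffling is predicted to compress better
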
# plot validation data by compress_level

valid_data <- valid_data %>% mutate(prediction = predict(bst, dvalid))

pal <- palette.colors(palette = "Okabe-Ito") %>% rep(length.out=length(datasets))
shp <- ifelse(duplicated(pal), 21, 24)
ggplot(valid_data, aes(x = prediction, y = improvement, color = dataset, shape = dataset)) + 
  geom_abline(slope = 1, intercept = 0, lty = 2) + 
  geom_vline(xintercept = 0, lty = 2, color = "orange") +
  geom_point(alpha=0.75) + 
  facet_wrap(~compress_level, ncol=4) +
  scale_color_manual(values = pal) +
  scale_shape_manual(values = shp) +
  theme_bw(base_size = 12)