diff --git a/R/csv.R b/R/csv.R index 8ac91393..a091fc19 100644 --- a/R/csv.R +++ b/R/csv.R @@ -697,7 +697,13 @@ read_csv_metadata <- function(csv_file) { for (line in metadata[[1]]) { if (!startsWith(line, "#") && is.null(csv_file_info[["variables"]])) { # if no # at the start of line, the line is the CSV header - all_names <- strsplit(line, ",")[[1]] + header_dt <- data.table::fread( + text = line, + header = FALSE, + check.names = FALSE, + data.table = FALSE + ) + all_names <- as.character(header_dt[1, ]) if (all(csv_file_info$algorithm != "fixed_param")) { csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")] csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))] @@ -915,8 +921,10 @@ check_csv_metadata_matches <- function(csv_metadata) { NULL } -# convert names like beta.1.1 to beta[1,1] +# convert names like beta.1.1 or beta(1,1) to beta[1,1] repair_variable_names <- function(names) { + names <- sub("\\(", "[", names) + names <- gsub("\\)", "", names) names <- sub("\\.", "[", names) names <- gsub("\\.", ",", names) names[grep("\\[", names)] <- @@ -983,7 +991,9 @@ variable_dims <- function(variable_names = NULL) { uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names)) var_names <- gsub("\\]", "", variable_names) for (var in uniq_variable_names) { - pattern <- paste0("^", var, "\\[") + # escape regex symbols + esc_var <- gsub("([][{}()+*?.^$|\\\\])", "\\\\\\1", var) + pattern <- paste0("^", esc_var, "\\[") var_indices <- var_names[grep(pattern, var_names)] var_indices <- gsub(pattern, "", var_indices) if (length(var_indices)) { diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index dc678a6b..827c3cc9 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -95,6 +95,19 @@ test_that("read_cmdstan_csv() fails with the no params listed", { "Supplied CSV file does not contain any variable names or data!") }) +test_that("variable_dims works for standard Stan names", { + vars <- c("beta[1]", "beta[2]", "beta[3]", "sigma") + dims <- variable_dims(vars) + expect_equal(dims$beta, 3) + expect_equal(dims$sigma, 1) +}) + +test_that("variable_dims handles names with regex metacharacters", { + vars <- c('SIGMA(1,1)[1,2]', 'SIGMA(1,1)[2,2]') + dims <- variable_dims(vars) + expect_equal(dims[["SIGMA(1,1)"]], c(2, 2)) +}) + test_that("read_cmdstan_csv() matches utils::read.csv", { csv_files <- c(test_path("resources", "csv", "model1-1-warmup.csv"), test_path("resources", "csv", "model1-2-warmup.csv")) @@ -896,3 +909,56 @@ test_that("read_cmdstan_csv() works with tilde expansion", { tildified_path <- file.path("~", fs::path_rel(full_path, "~")) expect_no_error(read_cmdstan_csv(tildified_path)) }) + +test_that("as_cmdstan_fit handles parameter names with parentheses and indices", { + skip_on_cran() + + csv_file <- withr::local_tempfile(fileext = ".csv") + lines <- c( + "# model = norm_model", + "# method = sample (Default)", + "# id = 1", + "# thin = 1", + "# save_warmup = 0", + 'lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,"Sigma(1,1)","Sigma(1,2)","Sigma(2,1)","Sigma(2,2)"', + "-65.951579,0.92571393,0.77752825,3,7,0,67.391073,0.2808549,-0.95718644,0.080662461,0.58814086", + "-66.417297,0.89632515,0.77752825,2,3,0,68.026905,0.3014893,-0.97834703,0.069719538,0.89573157" + ) + writeLines(lines, csv_file) + + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE) + + vars <- posterior::variables(fit$draws()) + expect_true(all( + c("Sigma[1,1]", "Sigma[1,2]", "Sigma[2,1]", "Sigma[2,2]") %in% vars + )) + + dims <- fit$metadata()$stan_variable_sizes + expect_equal(dims[["Sigma"]], c(2, 2)) +}) + +test_that("as_cmdstan_fit handles variable names with parentheses", { + skip_on_cran() + csv_file <- withr::local_tempfile(fileext = ".csv") + writeLines(c( + "# model = norm_model", + "# method = sample (Default)", + "# id = 1", + "# thin = 1", + "# save_warmup = 0", + "THETA4,SIGMA(1,1)", + "2.00000E+00,2.00000E+00", + "2.00000E+00,2.00000E+00" + ), con = csv_file) + + expect_no_error({ + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE, format = "draws_matrix") + }) + + draws <- fit$draws() + vars <- posterior::variables(draws) + + expect_equal(posterior::ndraws(draws), 2L) + expect_true(any(grepl("THETA4", vars))) + expect_true(any(grepl("SIGMA", vars))) +})