From 2fb7e0faaa1b4e9aaf800ee03140be7ad27535c2 Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Wed, 3 Dec 2025 14:20:46 -0800 Subject: [PATCH 1/4] Escaped regex characters in variable_dims() --- R/csv.R | 4 +++- tests/testthat/test-csv.R | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/R/csv.R b/R/csv.R index 8ac91393..d557aeb6 100644 --- a/R/csv.R +++ b/R/csv.R @@ -983,7 +983,9 @@ variable_dims <- function(variable_names = NULL) { uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names)) var_names <- gsub("\\]", "", variable_names) for (var in uniq_variable_names) { - pattern <- paste0("^", var, "\\[") + # escape regex symbols + esc_var <- gsub("([][{}()+*?.^$|\\\\])", "\\\\\\1", var) + pattern <- paste0("^", esc_var, "\\[") var_indices <- var_names[grep(pattern, var_names)] var_indices <- gsub(pattern, "", var_indices) if (length(var_indices)) { diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index dc678a6b..5624a079 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -95,6 +95,19 @@ test_that("read_cmdstan_csv() fails with the no params listed", { "Supplied CSV file does not contain any variable names or data!") }) +test_that("variable_dims works for standard Stan names", { + vars <- c("beta[1]", "beta[2]", "beta[3]", "sigma") + dims <- variable_dims(vars) + expect_equal(dims$beta, 3) + expect_equal(dims$sigma, 1) +}) + +test_that("variable_dims handles names with regex metacharacters", { + vars <- c('SIGMA(1,1)[1,2]', 'SIGMA(1,1)[2,2]') + dims <- variable_dims(vars) + expect_equal(dims[["SIGMA(1,1)"]], c(2, 2)) +}) + test_that("read_cmdstan_csv() matches utils::read.csv", { csv_files <- c(test_path("resources", "csv", "model1-1-warmup.csv"), test_path("resources", "csv", "model1-2-warmup.csv")) From 419b39c50ddf46fb3b211d8a81cddfd344373bb6 Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:35:19 -0800 Subject: [PATCH 2/4] Change parens in variables to brackets. Added a LLM-gen test --- R/csv.R | 4 +++- tests/testthat/test-csv.R | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/R/csv.R b/R/csv.R index d557aeb6..a0f998e3 100644 --- a/R/csv.R +++ b/R/csv.R @@ -915,8 +915,10 @@ check_csv_metadata_matches <- function(csv_metadata) { NULL } -# convert names like beta.1.1 to beta[1,1] +# convert names like beta.1.1 or beta(1,1) to beta[1,1] repair_variable_names <- function(names) { + names <- sub("\\(", "[", names) + names <- gsub("\\)", "", names) names <- sub("\\.", "[", names) names <- gsub("\\.", ",", names) names[grep("\\[", names)] <- diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 5624a079..414c7f49 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -909,3 +909,39 @@ test_that("read_cmdstan_csv() works with tilde expansion", { tildified_path <- file.path("~", fs::path_rel(full_path, "~")) expect_no_error(read_cmdstan_csv(tildified_path)) }) + +test_that("as_cmdstan_fit handles variable names with parentheses", { + csv_file <- tempfile(fileext = ".csv") + writeLines(c( + "# stan_version_major = 2", + "# stan_version_minor = 33", + "# stan_version_patch = 0", + "# model = norm_model", + "# method = sample (Default)", + "# sample", + "# num_samples = 2", + "# num_warmup = 0", + "# save_warmup = 0", + "# thin = 1", + "# random", + "# seed = 123", + "# algorithm = hmc", + "# metric = diag_e", + "# stepsize = 1", + "# id = 1", + "THETA4,SIGMA(1,1)", + "2.00000E+00,2.00000E+00", + "2.00000E+00,2.00000E+00" + ), con = csv_file) + + expect_no_error({ + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE, format = "draws_matrix") + }) + + draws <- fit$draws() + vars <- posterior::variables(draws) + + expect_equal(posterior::ndraws(draws), 2L) + expect_true(any(grepl("THETA4", vars))) + expect_true(any(grepl("SIGMA", vars))) +}) From 93e2b321b16a92cdcc5702bff91edcac8746a90f Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:01:57 -0800 Subject: [PATCH 3/4] Rely on fread for header parsing --- R/csv.R | 8 ++++++- tests/testthat/test-csv.R | 48 ++++++++++++++++----------------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/R/csv.R b/R/csv.R index a0f998e3..a091fc19 100644 --- a/R/csv.R +++ b/R/csv.R @@ -697,7 +697,13 @@ read_csv_metadata <- function(csv_file) { for (line in metadata[[1]]) { if (!startsWith(line, "#") && is.null(csv_file_info[["variables"]])) { # if no # at the start of line, the line is the CSV header - all_names <- strsplit(line, ",")[[1]] + header_dt <- data.table::fread( + text = line, + header = FALSE, + check.names = FALSE, + data.table = FALSE + ) + all_names <- as.character(header_dt[1, ]) if (all(csv_file_info$algorithm != "fixed_param")) { csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")] csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))] diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 414c7f49..c068ed10 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -910,38 +910,30 @@ test_that("read_cmdstan_csv() works with tilde expansion", { expect_no_error(read_cmdstan_csv(tildified_path)) }) -test_that("as_cmdstan_fit handles variable names with parentheses", { - csv_file <- tempfile(fileext = ".csv") - writeLines(c( - "# stan_version_major = 2", - "# stan_version_minor = 33", - "# stan_version_patch = 0", +test_that("as_cmdstan_fit handles parameter names with parentheses and indices", { + skip_on_cran() + + csv_file <- withr::local_tempfile(fileext = ".csv") + + lines <- c( "# model = norm_model", "# method = sample (Default)", - "# sample", - "# num_samples = 2", - "# num_warmup = 0", - "# save_warmup = 0", - "# thin = 1", - "# random", - "# seed = 123", - "# algorithm = hmc", - "# metric = diag_e", - "# stepsize = 1", "# id = 1", - "THETA4,SIGMA(1,1)", - "2.00000E+00,2.00000E+00", - "2.00000E+00,2.00000E+00" - ), con = csv_file) + "# thin = 1", + "# save_warmup = 0", + 'lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,"Sigma(1,1)","Sigma(1,2)","Sigma(2,1)","Sigma(2,2)"', + "-65.951579,0.92571393,0.77752825,3,7,0,67.391073,0.2808549,-0.95718644,0.080662461,0.58814086", + "-66.417297,0.89632515,0.77752825,2,3,0,68.026905,0.3014893,-0.97834703,0.069719538,0.89573157" + ) + writeLines(lines, csv_file) - expect_no_error({ - fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE, format = "draws_matrix") - }) + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE) - draws <- fit$draws() - vars <- posterior::variables(draws) + vars <- posterior::variables(fit$draws()) + expect_true(all( + c("Sigma[1,1]", "Sigma[1,2]", "Sigma[2,1]", "Sigma[2,2]") %in% vars + )) - expect_equal(posterior::ndraws(draws), 2L) - expect_true(any(grepl("THETA4", vars))) - expect_true(any(grepl("SIGMA", vars))) + dims <- fit$metadata()$stan_variable_sizes + expect_equal(dims[["Sigma"]], c(2, 2)) }) From 5527cdb97ad6b5173075bd13aae8ce7540c75404 Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:11:33 -0800 Subject: [PATCH 4/4] Re-added old test --- tests/testthat/test-csv.R | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index c068ed10..827c3cc9 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -914,7 +914,6 @@ test_that("as_cmdstan_fit handles parameter names with parentheses and indices", skip_on_cran() csv_file <- withr::local_tempfile(fileext = ".csv") - lines <- c( "# model = norm_model", "# method = sample (Default)", @@ -937,3 +936,29 @@ test_that("as_cmdstan_fit handles parameter names with parentheses and indices", dims <- fit$metadata()$stan_variable_sizes expect_equal(dims[["Sigma"]], c(2, 2)) }) + +test_that("as_cmdstan_fit handles variable names with parentheses", { + skip_on_cran() + csv_file <- withr::local_tempfile(fileext = ".csv") + writeLines(c( + "# model = norm_model", + "# method = sample (Default)", + "# id = 1", + "# thin = 1", + "# save_warmup = 0", + "THETA4,SIGMA(1,1)", + "2.00000E+00,2.00000E+00", + "2.00000E+00,2.00000E+00" + ), con = csv_file) + + expect_no_error({ + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE, format = "draws_matrix") + }) + + draws <- fit$draws() + vars <- posterior::variables(draws) + + expect_equal(posterior::ndraws(draws), 2L) + expect_true(any(grepl("THETA4", vars))) + expect_true(any(grepl("SIGMA", vars))) +})