Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ CmdStanFit$set("public", name = "init", value = init)
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (logical) Whether to show verbose logging during compilation.
#' @param hessian (logical) Whether to expose the (experimental) hessian method.
#' @param force_recompile (logical) Whether to recompile cached model methods.
#'
#' @examples
#' \dontrun{
Expand All @@ -332,25 +333,26 @@ CmdStanFit$set("public", name = "init", value = init)
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE, force_recompile = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
call. = FALSE)
}
require_suggested_package("Rcpp")
if (length(private$model_methods_env_$hpp_code_) == 0) {
if (length(private$model_methods_env_$hpp_code_) == 0 && (
is.null(private$model_methods_env_$obj_file_) ||
!file.exists(private$model_methods_env_$obj_file_))) {
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
"the model must be compiled again", call. = FALSE)
}
if (hessian) {
message("The hessian method relies on higher-order autodiff ",
"which is still experimental. Please report any compilation ",
"errors that you encounter")
warning("The hessian argument is deprecated and will be removed in a future release.\n",
"The hessian method is now exposed by default.")
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
expose_model_methods(private$model_methods_env_, verbose = verbose,
force_recompile = force_recompile)
}
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
invisible(NULL)
Expand Down
2 changes: 2 additions & 0 deletions R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ build_cmdstan <- function(dir,
clean_cmdstan <- function(dir = cmdstan_path(),
cores = getOption("mc.cores", 2),
quiet = FALSE) {
unlink(file.path(dir, "model_methods.o"))
unlink(file.path(dir, "model_methods.cpp"))
withr::with_path(
c(
toolchain_PATH_env_var(),
Expand Down
18 changes: 11 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ CmdStanModel <- R6::R6Class(
#' @param compile_model_methods (logical) Compile additional model methods
#' (`log_prob()`, `grad_log_prob()`, `constrain_variables()`,
#' `unconstrain_variables()`).
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
#' be compiled with the model methods?
#' @param compile_hessian_method (logical) Deprecated and will be removed in a future release.
#' The hessian method is now compiled by default.
#' @param compile_standalone (logical) Should functions in the Stan model be
#' compiled for use in R? If `TRUE` the functions will be available via the
#' `functions` field in the compiled model object. This can also be done after
Expand Down Expand Up @@ -504,6 +504,10 @@ compile <- function(quiet = TRUE,
warning("'threads' is deprecated. Please use 'cpp_options = list(stan_threads = TRUE)' instead.")
cpp_options[["stan_threads"]] <- TRUE
}
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is now compiled by default.")
compile_hessian_method <- FALSE
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
Expand Down Expand Up @@ -655,9 +659,10 @@ compile <- function(quiet = TRUE,
run_log <- wsl_compatible_run(
command = make_cmd(),
args = c(wsl_safe_path(tmp_exe),
cpp_options_to_compile_flags(cpp_options),
cpp_options_to_compile_flags(c(cpp_options, list("KEEP_OBJECT"="true"))),
stancflags_val),
wd = cmdstan_path(),
env = c("current", "CXXFLAGS" = "-fPIC"),
echo = !quiet || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
spinner = quiet && rlang::is_interactive() && !identical(Sys.getenv("IN_PKGDOWN"), "true"),
Expand Down Expand Up @@ -708,6 +713,7 @@ compile <- function(quiet = TRUE,
file.remove(exe)
}
file.copy(tmp_exe, exe, overwrite = TRUE)
private$model_methods_env_$obj_file_ <- paste0(temp_file_no_ext, ".o")
if (os_is_wsl()) {
res <- processx::run(
command = "wsl",
Expand All @@ -726,11 +732,9 @@ compile <- function(quiet = TRUE,
private$precompile_stanc_options_ <- NULL
private$precompile_include_paths_ <- NULL

if(!dry_run) {
if (!dry_run) {
if (compile_model_methods) {
expose_model_methods(env = private$model_methods_env_,
verbose = !quiet,
hessian = compile_hessian_method)
expose_model_methods(private$model_methods_env_, verbose = !quiet)
}
}
invisible(self)
Expand Down
117 changes: 94 additions & 23 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -728,45 +728,116 @@ get_cmdstan_flags <- function(flag_name) {
paste(flags, collapse = " ")
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags <- function(expr, model_methods = FALSE) {
cxxflags <- get_cmdstan_flags("CXXFLAGS")
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
cmdstanr_includes <- paste0("-I", shQuote(cmdstanr_includes))

r_includes <- paste(
paste0("-I", shQuote(system.file("include", package = "Rcpp", mustWork = TRUE))),
paste0("-I", shQuote(R.home(component = "include")))
)

libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = " ")
if (.Platform$OS.type == "windows") {
if (os_is_windows()) {
libs <- paste(libs, "-fopenmp")
}
lib_paths <- c("/stan/lib/stan_math/lib/tbb/",
"/stan/lib/stan_math/lib/sundials_6.1.1/lib/")
withr::with_path(paste0(cmdstan_path(), lib_paths),
withr::with_makevars(
c(
USE_CXX14 = 1,
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
PKG_LIBS = libs
),
Rcpp::sourceCpp(code = code, env = env, verbose = verbose)
new_makevars <- c(
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste(cxxflags, cmdstanr_includes, r_includes, collapse = " "),
PKG_LIBS = libs
)
if (os_is_windows() && model_methods) {
new_makevars <- c(
new_makevars,
SHLIB_LD = paste0(rtools4x_toolchain_path(),"/gcc"),
LOCAL_CPPFLAGS = paste0("-I'",rtools4x_toolchain_path(),"/../include'"),
LOCAL_LIBS = paste0("-L'",rtools4x_toolchain_path(),"/../lib' -lstdc++"),
BINPREF = paste0(rtools4x_toolchain_path(), "/")
)
}
withr::with_path(
c(
paste0(cmdstan_path(), lib_paths),
toolchain_PATH_env_var()
),
withr::with_makevars(new_makevars, expr)
)
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags(Rcpp::sourceCpp(code = code, env = env, verbose = verbose))
invisible(NULL)
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))
initialize_method_functions <- function(env, so_name) {
env$model_ptr <-
function(...) { .Call("model_ptr_", ..., PACKAGE = so_name) }
env$log_prob <-
function(...) { .Call("log_prob_", ..., PACKAGE = so_name) }
env$grad_log_prob <-
function(...) { .Call("grad_log_prob_", ..., PACKAGE = so_name) }
env$hessian <-
function(...) { .Call("hessian_", ..., PACKAGE = so_name) }
env$get_num_upars <-
function(...) { .Call("get_num_upars_", ..., PACKAGE = so_name) }
env$get_param_metadata <-
function(...) { .Call("get_param_metadata_", ..., PACKAGE = so_name) }
env$unconstrain_variables <-
function(...) { .Call("unconstrain_variables_", ..., PACKAGE = so_name) }
env$constrain_variables <-
function(...) { .Call("constrain_variables_", ..., PACKAGE = so_name) }
env$unconstrained_param_names <-
function(...) { .Call("unconstrained_param_names_", ..., PACKAGE = so_name) }
env$constrained_param_names <-
function(...) { .Call("constrained_param_names_", ..., PACKAGE = so_name) }
}

if (hessian) {
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
expose_model_methods <- function(env, force_recompile = FALSE, verbose = FALSE) {
precomp_methods_file <- file.path(cmdstan_path(), "model_methods.o")
if (file.exists(precomp_methods_file) && force_recompile) {
unlink(precomp_methods_file)
}
model_methods_cpp <- system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)
source_file <- paste0(strip_ext(precomp_methods_file), ".cpp")
file.copy(model_methods_cpp, source_file, overwrite = FALSE)

model_obj_file <- env$obj_file_
if (!file.exists(model_obj_file)) {
if (rlang::is_interactive()) {
message("Model object file not found, recompiling model...")
}
temp_hpp_file <- tempfile()
writeLines(env$hpp_code_, con = paste0(temp_hpp_file, ".cpp"))
model_obj_file <- paste0(temp_hpp_file, ".o")
}

if (!file.exists(precomp_methods_file) && rlang::is_interactive()) {
message("Compiling and caching additional model methods...")
}
if (rlang::is_interactive()) {
message("Linking precompiled model methods to model object file...")
}

methods_dll <- tempfile(fileext = .Platform$dynlib.ext)
with_cmdstan_flags(
processx::run(
command = file.path(R.home(component = "bin"), "R"),
args = c("CMD", "SHLIB", repair_path(model_obj_file), repair_path(precomp_methods_file),
"-o", repair_path(methods_dll)),
echo = verbose || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
error_on_status = FALSE
),
model_methods = TRUE
)

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
env$methods_dll_info <- with_cmdstan_flags(dyn.load(methods_dll, local = TRUE, now = TRUE))
initialize_method_functions(env, strip_ext(basename(methods_dll)))
invisible(NULL)
}

Expand Down
41 changes: 0 additions & 41 deletions inst/include/hessian.cpp

This file was deleted.

Loading