diff --git a/lib/explorer/backend/sql_context.ex b/lib/explorer/backend/sql_context.ex new file mode 100644 index 000000000..0754ba1e1 --- /dev/null +++ b/lib/explorer/backend/sql_context.ex @@ -0,0 +1,13 @@ +defmodule Explorer.Backend.SQLContext do + @type t :: struct() + @type c :: Explorer.SQLContext.t() + @type df :: Explorer.DataFrame.t() + @type result(t) :: {:ok, t} | {:error, term()} + + @callback register(c, String.t(), df) :: c + @callback unregister(c, String.t()) :: c + @callback execute(c, String.t()) :: result(df) + @callback get_tables(c) :: list(String.t()) + + def new(ctx), do: %Explorer.SQLContext{ctx: ctx} +end diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 36c157bb4..b0bf36fba 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -465,5 +465,11 @@ defmodule Explorer.PolarsBackend.Native do def message_on_gc(_pid, _payload), do: err() def is_message_on_gc(_term), do: err() + def sql_context_new(), do: err() + def sql_context_register(_ctx, _name, _df), do: err() + def sql_context_unregister(_ctx, _name), do: err() + def sql_context_execute(_ctx, _query), do: err() + def sql_context_get_tables(_ctx), do: err() + defp err, do: :erlang.nif_error(:nif_not_loaded) end diff --git a/lib/explorer/polars_backend/sql_context.ex b/lib/explorer/polars_backend/sql_context.ex new file mode 100644 index 000000000..22a975610 --- /dev/null +++ b/lib/explorer/polars_backend/sql_context.ex @@ -0,0 +1,37 @@ +defmodule Explorer.PolarsBackend.SQLContext do + @moduledoc false + + defstruct resource: nil + + alias Explorer.Native + alias Explorer.PolarsBackend.Native + alias Explorer.PolarsBackend.Shared + + @type t :: %__MODULE__{resource: reference()} + + @behaviour Explorer.Backend.SQLContext + + def new() do + ctx = Native.sql_context_new() + Explorer.Backend.SQLContext.new(ctx) + end + + def register(%Explorer.SQLContext{ctx: ctx} = context, name, %Explorer.DataFrame{data: df}) do + Native.sql_context_register(ctx, name, df) + context + end + + def unregister(%Explorer.SQLContext{ctx: ctx} = context, name) do + Native.sql_context_unregister(ctx, name) + context + end + + def execute(%Explorer.SQLContext{ctx: ctx}, query) do + case Native.sql_context_execute(ctx, query) do + {:ok, polars_ldf} -> Shared.create_dataframe(polars_ldf) + {:error, error} -> {:error, RuntimeError.exception(error)} + end + end + + def get_tables(%Explorer.SQLContext{ctx: ctx}), do: Native.sql_context_get_tables(ctx) +end diff --git a/lib/explorer/sql_context.ex b/lib/explorer/sql_context.ex new file mode 100644 index 000000000..2eb38c20e --- /dev/null +++ b/lib/explorer/sql_context.ex @@ -0,0 +1,33 @@ +defmodule Explorer.SQLContext do + @enforce_keys [:ctx] + defstruct [:ctx] + + alias __MODULE__, as: SQLContext + + @type t :: %SQLContext{ctx: Explorer.Backend.SQLContext.t()} + + alias Explorer.Backend.SQLContext + alias Explorer.Shared + + def new(args \\ [], opts \\ []), do: Shared.apply_init(backend(), :new, args, opts) + + def register(ctx, name, df, opts \\ []) do + Shared.apply_init(backend(), :register, [ctx, name, df], opts) + end + + def unregister(ctx, name, opts \\ []) do + Shared.apply_init(backend(), :unregister, [ctx, name], opts) + end + + def execute(ctx, query, opts \\ []) do + Shared.apply_init(backend(), :execute, [ctx, query], opts) + end + + def get_tables(ctx, opts \\ []) do + Shared.apply_init(backend(), :get_tables, [ctx], opts) + end + + defp backend do + Module.concat([Explorer.Backend.get(), "SQLContext"]) + end +end diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index d16015f84..ac078ce8c 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -24,6 +24,7 @@ mod expressions; mod lazyframe; mod local_message; mod series; +mod sql_context; pub use datatypes::{ ExDataFrame, ExDataFrameRef, ExExpr, ExExprRef, ExLazyFrame, ExLazyFrameRef, ExSeries, @@ -33,6 +34,7 @@ pub use datatypes::{ pub use error::ExplorerError; use expressions::*; use series::*; +pub use sql_context::*; mod atoms { rustler::atoms! { diff --git a/native/explorer/src/sql_context.rs b/native/explorer/src/sql_context.rs new file mode 100644 index 000000000..860a1502b --- /dev/null +++ b/native/explorer/src/sql_context.rs @@ -0,0 +1,70 @@ +use crate::{ExDataFrame, ExLazyFrame, ExplorerError}; +use polars::{prelude::IntoLazy, sql::SQLContext}; +use rustler::{NifStruct, Resource, ResourceArc}; +use std::sync::{Arc, Mutex}; +pub struct ExSQLContextRef(pub Arc>); + +#[rustler::resource_impl] +impl Resource for ExSQLContextRef {} + +#[derive(NifStruct)] +#[module = "Explorer.PolarsBackend.SQLContext"] +pub struct ExSQLContext { + pub resource: ResourceArc, +} + +impl ExSQLContextRef { + pub fn new(ctx: SQLContext) -> Self { + Self(Arc::new(Mutex::new(ctx))) + } +} + +impl ExSQLContext { + pub fn new(ctx: SQLContext) -> Self { + Self { + resource: ResourceArc::new(ExSQLContextRef::new(ctx)), + } + } + + // Function to get a lock on the inner SQLContext + pub fn lock_inner(&self) -> std::sync::MutexGuard { + self.resource.0.lock().unwrap() + } +} + +#[rustler::nif] +fn sql_context_new() -> ExSQLContext { + let ctx = SQLContext::new(); + ExSQLContext::new(ctx) +} + +#[rustler::nif] +fn sql_context_register(context: ExSQLContext, name: &str, df: ExDataFrame) { + let mut ctx = context.lock_inner(); + let ldf = df.clone_inner().lazy(); + ctx.register(name, ldf) +} + +#[rustler::nif] +fn sql_context_unregister(context: ExSQLContext, name: &str) { + let mut ctx = context.lock_inner(); + ctx.unregister(name) +} + +#[rustler::nif] +fn sql_context_execute(context: ExSQLContext, query: &str) -> Result { + let mut ctx = context.lock_inner(); + match ctx.execute(query) { + Ok(lazy_frame) => Ok(ExLazyFrame::new(lazy_frame)), + Err(e) => Err(ExplorerError::Other(format!( + "Failed to execute query: {}", + e + ))), + } +} + +#[rustler::nif] +fn sql_context_get_tables(context: ExSQLContext) -> Vec { + let ctx = context.lock_inner(); + ctx.get_tables() +} diff --git a/test/explorer/sql_context_test.exs b/test/explorer/sql_context_test.exs new file mode 100644 index 000000000..04c8214f9 --- /dev/null +++ b/test/explorer/sql_context_test.exs @@ -0,0 +1,89 @@ +defmodule Explorer.SQLContextTest do + use ExUnit.Case, async: true + + require Explorer.DataFrame + + alias Explorer.DataFrame, as: DF + alias Explorer.SQLContext + + describe "execute" do + test "execute without any data frame registered" do + case SQLContext.new() + |> SQLContext.execute("select 1 as column_a union all select 2 as column_a") do + {:ok, result} -> + assert result != nil + assert DF.compute(result) |> DF.to_columns(atom_keys: true) == %{column_a: [1, 2]} + + {:error, reason} -> + flunk("SQL query execution failed with reason: #{inspect(reason)}") + end + end + + test "execute with registering single data frame" do + df = DF.new(%{column_a: [1, 2, 3]}) + + case SQLContext.new() + |> SQLContext.register("t1", df) + |> SQLContext.execute( + "select 2 * t.column_a as column_2a from t1 as t where t.column_a < 3" + ) do + {:ok, result} -> + assert result != nil + assert DF.compute(result) |> DF.to_columns(atom_keys: true) == %{column_2a: [2, 4]} + + {:error, reason} -> + flunk("SQL query execution failed with reason: #{inspect(reason)}") + end + end + + test "execute with registering multiple data frames" do + df1 = DF.new(%{column_1a: [1, 2, 3]}) + + df2 = + DF.new(%{ + column_2a: [1, 2, 4], + column_2b: ["a", "b", "c"] + }) + + case SQLContext.new() + |> SQLContext.register("t1", df1) + |> SQLContext.register("t2", df2) + |> SQLContext.execute( + "select t2.column_2b as col from t1 join t2 on t1.column_1a = t2.column_2a" + ) do + {:ok, result} -> + assert result != nil + assert DF.compute(result) |> DF.to_columns(atom_keys: true) == %{col: ["a", "b"]} + + {:error, reason} -> + flunk("SQL query execution failed with reason: #{inspect(reason)}") + end + end + + test "get_tables get registered tables" do + df = DF.new(%{col: [1]}) + + tables = + SQLContext.new() + |> SQLContext.register("t1", df) + |> SQLContext.register("t2", df) + |> SQLContext.get_tables() + + assert tables == ["t1", "t2"] + end + + test "unregister" do + df = DF.new(%{col: [1]}) + + tables = + SQLContext.new() + |> SQLContext.register("t1", df) + |> SQLContext.register("t2", df) + |> SQLContext.register("t3", df) + |> SQLContext.unregister("t1") + |> SQLContext.get_tables() + + assert tables == ["t2", "t3"] + end + end +end