Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ if (NOT ENABLE_ONEMKL_SYCL AND
NOT ENABLE_ROCSPARSE AND
NOT ENABLE_CUSPARSE)
set(SPBLAS_CPU_BACKEND ON)
set(SPBLAS_REFERENCE_BACKEND ON)
endif()

# turn on/off debug logging
Expand Down
3 changes: 3 additions & 0 deletions include/spblas/algorithms/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@
#include <spblas/algorithms/scaled.hpp>
#include <spblas/algorithms/scaled_impl.hpp>

#include <spblas/algorithms/conjugated.hpp>
#include <spblas/algorithms/conjugated_impl.hpp>

#include <spblas/algorithms/transpose.hpp>
#include <spblas/algorithms/transpose_impl.hpp>
13 changes: 13 additions & 0 deletions include/spblas/algorithms/conjugated.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include <spblas/concepts.hpp>

namespace spblas {

template <matrix M>
auto conjugated(M&& m);

template <vector V>
auto conjugated(V&& v);

} // namespace spblas
30 changes: 30 additions & 0 deletions include/spblas/algorithms/conjugated_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <complex>
#include <utility>

#include <spblas/concepts.hpp>
#include <spblas/detail/type_traits.hpp>
#include <spblas/views/conjugated_view.hpp>

namespace spblas {

template <vector V>
auto conjugated(V&& v) {
if constexpr (__detail::is_std_complex_v<tensor_scalar_t<V>>) {
return conjugated_view(std::forward<V>(v));
} else {
return std::forward<V>(v);
}
}

template <matrix M>
auto conjugated(M&& m) {
if constexpr (__detail::is_std_complex_v<tensor_scalar_t<M>>) {
return conjugated_view(std::forward<M>(m));
} else {
return std::forward<M>(m);
}
}

} // namespace spblas
54 changes: 47 additions & 7 deletions include/spblas/backend/generate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,46 @@
#include <numeric>
#include <random>
#include <set>
#include <tuple>
#include <type_traits>
#include <vector>

#include <spblas/detail/ranges.hpp>
#include <spblas/detail/type_traits.hpp>

namespace spblas {

namespace __detail {

template <typename T, typename G>
T random_uniform(G& g, T lo, T hi) {
if constexpr (is_std_complex_v<T>) {
using value_type = typename std::remove_cvref_t<T>::value_type;
std::uniform_real_distribution<value_type> d(lo.real(), hi.real());
return T(d(g), d(g));
} else if constexpr (std::is_integral_v<T>) {
std::uniform_int_distribution<T> d(lo, hi);
return d(g);
} else {
std::uniform_real_distribution<T> d(lo, hi);
return d(g);
}
}

template <typename T, typename G>
T random_gaussian(G& g, T mean, T stddev) {
if constexpr (is_std_complex_v<T>) {
using value_type = typename std::remove_cvref_t<T>::value_type;
std::normal_distribution<value_type> d(mean.real(), stddev.real());
return T(d(g), d(g));
} else {
std::normal_distribution<T> d(mean, stddev);
return d(g);
}
}

} // namespace __detail

template <typename T = float, typename I = index_t>
auto generate_coo(std::size_t m, std::size_t n, std::size_t nnz,
std::size_t seed = 0) {
Expand All @@ -25,7 +59,6 @@ auto generate_coo(std::size_t m, std::size_t n, std::size_t nnz,
std::mt19937 g(seed);
std::uniform_int_distribution<I> d_row(0, m - 1);
std::uniform_int_distribution<I> d_col(0, n - 1);
std::uniform_real_distribution d_val(0.0, 100.0);
std::set<std::pair<I, I>> entries;

for (std::size_t i = 0; i < nnz; i++) {
Expand All @@ -38,10 +71,19 @@ auto generate_coo(std::size_t m, std::size_t n, std::size_t nnz,
entries.emplace(row, col);
rowind.push_back(row);
colind.push_back(col);
values.push_back(d_val(g));
values.push_back(__detail::random_uniform<T>(g, T{0}, T{100}));
}

__ranges::sort(__ranges::views::zip(rowind, colind, values));
__ranges::sort(__ranges::views::zip(rowind, colind, values),
[](const auto& left, const auto& right) {
if (std::get<0>(left) < std::get<0>(right)) {
return true;
}
if (std::get<0>(right) < std::get<0>(left)) {
return false;
}
return std::get<1>(left) < std::get<1>(right);
});

return std::tuple(values, rowind, colind, spblas::index<I>(m, n), I(nnz));
}
Expand Down Expand Up @@ -131,10 +173,9 @@ auto generate_dense(std::size_t m, std::size_t n, std::size_t seed = 0) {
v.reserve(m * n);

std::mt19937 g(seed);
std::uniform_real_distribution d(0.0, 100.0);

for (std::size_t i = 0; i < m * n; i++) {
v.push_back(d(g));
v.push_back(__detail::random_uniform<T>(g, T{0}, T{100}));
}

return std::tuple(v, spblas::index(m, n));
Expand All @@ -146,10 +187,9 @@ auto generate_gaussian(std::size_t m, std::size_t n, std::size_t seed = 0) {
v.reserve(m * n);

std::mt19937 g(seed);
std::normal_distribution d{0.0, 1.0};

for (std::size_t i = 0; i < m * n; i++) {
v.push_back(d(g));
v.push_back(__detail::random_gaussian<T>(g, T{0}, T{1}));
}

return std::tuple(v, spblas::index(m, n));
Expand Down
1 change: 1 addition & 0 deletions include/spblas/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace spblas {
- Instantiations of csc_view<...>
- Instantiations of mdspan<...> with rank 2
- Instantiations of scaled_view<T> where M is a matrix
- Instantiations of conjugated_view<T> where M is a matrix
*/

template <typename M>
Expand Down
22 changes: 22 additions & 0 deletions include/spblas/detail/type_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <complex>
#include <type_traits>

namespace spblas {

namespace __detail {

template <typename T>
struct is_std_complex : std::false_type {};

template <typename T>
struct is_std_complex<std::complex<T>> : std::true_type {};

template <typename T>
inline constexpr bool is_std_complex_v =
is_std_complex<std::remove_cvref_t<T>>::value;

} // namespace __detail

} // namespace spblas
20 changes: 20 additions & 0 deletions include/spblas/detail/view_inspectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ auto get_scaling_factor(T&& t, U&& u) {
}
}

// Inspect a tensor: does it have a conjugation view? Returns true if an odd
// number of conjugated_view appears in the chain of bases.
template <tensor T>
bool is_conjugated(T&& t) {
if constexpr (has_base<T>) {
if constexpr (is_conjugated_view_v<T>) {
return !is_conjugated(t.base());
} else {
return is_conjugated(t.base());
}
} else {
return is_conjugated_view_v<T>;
}
}

template <tensor T, tensor U>
bool is_conjugated(T&& t, U&& u) {
return is_conjugated(std::forward<T>(t)) || is_conjugated(std::forward<U>(u));
}

template <tensor T>
bool has_scaling_factor(T&& t) {
return get_scaling_factor(t).has_value();
Expand Down
13 changes: 13 additions & 0 deletions include/spblas/vendor/aoclsparse/spgemm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <aoclsparse.h>
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include "detail/detail.hpp"
Expand Down Expand Up @@ -40,6 +41,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}
Comment on lines +44 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, so here is a design question it seems like we should have the following checks:

not transposed + not conjugate == is_nontransposed
not transposed + conjugate == is_conjugated
transposed + not conjugate == is_transposed
transposed + conjugate == is_conjugate_transposed

MKL and AOCL and other libraries support nontranspose/transpose/conjugatetranspose but not conjugate. in the is_conjugated() function do we need to rule out is_transposed() so we can correctly capture conjugate_transpose ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind, it just hasn't been implemented yet here. see mkl backend for actual implementation.

Copy link
Contributor

@spencerpatty spencerpatty Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User Data: input is CSR

csr_view<I,O,T>csrV(/* user data in csr */);
matrix_handle csrA(csrV);

sparse::spmv(transposed(csrA), x,y) // y = A^T * x
sparse::spmv(conjugated(transposed(csrA)), x,y) // y = A^H * x
sparse::spmm(csrA,W,Z)  // Z = A*W

converts to in MKL backend

mkl::sparse::matrix_handle_t csrA;
mkl::sparse::set_csr_data(csrA, /* user data */)
mkl::sparse::gemv(csrA, transpose::trans, x,y)
mkl::sparse::gemv(csrA, transpose::conjtrans, x,y)
mkl::sparse::gemm(csrA, transpose::nontrans, W, Z)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is orthogonal to PR, but as this is a common use case, we will want to add a transposed_view object to the mix eventually.


using T = tensor_scalar_t<C>;
using I = tensor_index_t<C>;
using O = tensor_offset_t<C>;
Expand Down Expand Up @@ -111,6 +118,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
using I = tensor_index_t<C>;
using O = tensor_offset_t<C>;

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/spmm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -44,6 +45,12 @@ void multiply(A&& a, X&& x, Y&& y) {
auto x_base = __detail::get_ultimate_base(x);
auto y_base = __detail::get_ultimate_base(y);

if (__detail::is_conjugated(a) || __detail::is_conjugated(x) ||
__detail::is_conjugated(y)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base);

aoclsparse_operation opA = __aoclsparse::get_transpose(a);
Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -38,6 +39,12 @@ void multiply(A&& a, X&& x, Y&& y) {
auto a_base = __detail::get_ultimate_base(a);
auto x_base = __detail::get_ultimate_base(x);

if (__detail::is_conjugated(a) || __detail::is_conjugated(x) ||
__detail::is_conjugated(y)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base);
aoclsparse_operation opA = __aoclsparse::get_transpose(a);

Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -46,6 +47,12 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(x)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

using T = tensor_scalar_t<A>;
using I = tensor_index_t<A>;
using O = tensor_offset_t<A>;
Expand Down
26 changes: 26 additions & 0 deletions include/spblas/vendor/armpl/multiply_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <spblas/vendor/armpl/detail/detail.hpp>

#include <stdexcept>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
Expand All @@ -21,6 +23,12 @@ void multiply(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand All @@ -47,6 +55,12 @@ void multiply(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down Expand Up @@ -92,6 +106,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down Expand Up @@ -121,6 +141,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
auto c_handle = info.state_.c_handle;

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

__armpl::export_matrix_handle(info, c, c_handle);
}

Expand Down
Loading
Loading