Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ def test():
return available_backends, fused_attention_backend

backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
# with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends


Expand Down
138 changes: 73 additions & 65 deletions transformer_engine/common/fused_attn/fused_attn.cpp

Large diffs are not rendered by default.

174 changes: 101 additions & 73 deletions transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV,
Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
Expand All @@ -63,21 +64,23 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);

Expand Down
Loading
Loading