From 06dcc06794311a4bb2ef13c45c5973d5ebabdbec Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Wed, 14 Jan 2026 10:11:00 +0000
Subject: [PATCH] issue/920 RoPE supports longrope

---
 include/infinicore/nn/rope.hpp | 56 ++++++++++++++++++++++++++++++----
 src/infinicore/nn/rope.cc      | 28 ++++++++++++++---
 2 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/include/infinicore/nn/rope.hpp b/include/infinicore/nn/rope.hpp
index 26f01a413..0f12ff117 100644
--- a/include/infinicore/nn/rope.hpp
+++ b/include/infinicore/nn/rope.hpp
@@ -17,6 +17,47 @@ class RoPE : public Module {
         GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
     };
 
+    enum class ScalingType {
+        DEFAULT = 0,  // Default RoPE
+        LONGROPE = 1  // Long-RoPE
+    };
+
+    class ScalingConfig {
+    public:
+        virtual ~ScalingConfig() = default;
+        ScalingType type() const { return type_; }
+
+    protected:
+        ScalingType type_ = ScalingType::DEFAULT;
+        ScalingConfig(ScalingType type) : type_(type) {}
+    };
+
+    // LongRoPE scaling configuration
+    class LongRopeConfig : public ScalingConfig {
+    protected:
+        std::vector<float> short_factor_;
+        std::vector<float> long_factor_;
+        size_t original_max_position_embeddings_;
+        float factor_;
+
+    public:
+        LongRopeConfig(
+            std::vector<float> short_factor,
+            std::vector<float> long_factor,
+            size_t original_max_position_embeddings,
+            float factor = 1.0f)
+            : ScalingConfig(ScalingType::LONGROPE),
+              short_factor_(short_factor),
+              long_factor_(long_factor),
+              original_max_position_embeddings_(original_max_position_embeddings),
+              factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
+        ~LongRopeConfig() override = default;
+        size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
+        const std::vector<float> &short_factor() const { return short_factor_; }
+        const std::vector<float> &long_factor() const { return long_factor_; }
+        float factor() const { return factor_; }
+    };
+
     /**
      * @brief Construct a RoPE layer
      *
@@ -26,13 +67,15 @@
      * @param algo RoPE algorithm type (default: Algo::GPT_J)
      * @param dtype Data type for sin/cos cache (default: DataType::F32)
      * @param device Device to create the cache on
+     * @param scaling RoPE scaling configuration (default: nullptr, i.e. no scaling)
      */
     RoPE(size_t head_dim,
         size_t max_seq_len,
         double theta = 10000.0,
         Algo algo = Algo::GPT_J,
         const DataType &dtype = DataType::F32,
-        const Device &device = Device());
+        const Device &device = Device(),
+        std::shared_ptr<ScalingConfig> scaling = nullptr);
 
     /**
      * @brief Forward pass: apply RoPE to a tensor
@@ -88,11 +131,12 @@
 private:
     void initialize_cache();
 
-    size_t head_dim_;    // Dimension of each attention head
-    size_t max_seq_len_; // Maximum sequence length
-    double theta_;       // Base frequency for rotary embeddings
-    Algo algo_;          // RoPE algorithm type
-    DataType dtype_;     // Data type for cache tables
+    size_t head_dim_;                        // Dimension of each attention head
+    size_t max_seq_len_;                     // Maximum sequence length
+    double theta_;                           // Base frequency for rotary embeddings
+    Algo algo_;                              // RoPE algorithm type
+    DataType dtype_;                         // Data type for cache tables
+    std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling configuration
 };
 
 } // namespace infinicore::nn
diff --git a/src/infinicore/nn/rope.cc b/src/infinicore/nn/rope.cc
index 8b1eae696..26403a31c 100644
--- a/src/infinicore/nn/rope.cc
+++ b/src/infinicore/nn/rope.cc
@@ -16,12 +16,14 @@ RoPE::RoPE(size_t head_dim,
            double theta,
            Algo algo,
            const DataType &dtype,
-           const Device &device)
+           const Device &device,
+           std::shared_ptr<ScalingConfig> scaling)
     : head_dim_(head_dim),
       max_seq_len_(max_seq_len),
       theta_(theta),
       algo_(algo),
-      dtype_(dtype) {
+      dtype_(dtype),
+      scaling_(scaling) {
     if (head_dim % 2 != 0) {
         throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
     }
@@ -54,14 +56,30 @@ void RoPE::initialize_cache() {
         for (size_t j = 0; j < cache_dim; j++) {
             // GPT-J style inverse frequency: theta^(-2j/head_dim)
             // Compute directly in float to avoid double->float casting
-            float inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            float inv_freq;
+            float table_factor = 1.0f;
+            if (scaling_ == nullptr) {
+                inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            } else if (scaling_->type() == ScalingType::LONGROPE) {
+                std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
+                table_factor = lr->factor();
+                float _ext;
+                if (pos < lr->original_max_position_embeddings()) {
+                    _ext = lr->short_factor()[j];
+                } else {
+                    _ext = lr->long_factor()[j];
+                }
+                inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
+            } else {
+                inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            }
 
             // Compute angle: position * inverse_frequency
             float angle = static_cast<float>(pos) * inv_freq;
 
             // Compute sin and cos directly on float
-            sin_data[pos * cache_dim + j] = std::sin(angle);
-            cos_data[pos * cache_dim + j] = std::cos(angle);
+            sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
+            cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
         }
     }
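
Note (not part of the patch): a minimal usage sketch of the new scaling hook, assuming the
namespaces and defaults shown in rope.hpp above; the factor values below are placeholders, not
real model parameters. With factor = 32 and original_max_position_embeddings = 4096, the
LongRopeConfig constructor computes an attention scale of sqrt(1 + ln 32 / ln 4096) ~= 1.19,
which initialize_cache() multiplies into the cached sin/cos tables.

    #include <memory>
    #include <vector>
    #include "infinicore/nn/rope.hpp"

    // Assumption: DataType and Device are reachable from infinicore or infinicore::nn.
    using namespace infinicore;
    using namespace infinicore::nn;

    int main() {
        const size_t head_dim = 128;
        const size_t max_seq_len = 131072;

        // One rescale factor per rotary dimension pair (head_dim / 2 entries each);
        // real models ship these values in their config, these are placeholders.
        std::vector<float> short_factor(head_dim / 2, 1.0f);
        std::vector<float> long_factor(head_dim / 2, 4.0f);

        auto scaling = std::make_shared<RoPE::LongRopeConfig>(
            short_factor, long_factor,
            /*original_max_position_embeddings=*/4096,
            /*factor=*/32.0f);

        // Positions below 4096 use short_factor; positions at or beyond it use long_factor.
        RoPE rope(head_dim, max_seq_len, /*theta=*/10000.0,
                  RoPE::Algo::GPT_J, DataType::F32, Device(), scaling);
        return 0;
    }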