56 changes: 50 additions & 6 deletions include/infinicore/nn/rope.hpp
@@ -17,6 +17,47 @@ class RoPE : public Module {
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
};

enum class ScalingType {
DEFAULT = 0, // Standard RoPE, no scaling
LONGROPE = 1 // Long-RoPE scaling (per-dimension short/long factors plus an attention scaling factor)
};

class ScalingConfig {
public:
virtual ~ScalingConfig() = default;
ScalingType type() const { return type_; }

protected:
ScalingType type_ = ScalingType::DEFAULT;
ScalingConfig(ScalingType type) : type_(type) {}
};

// Long-RoPE scaling configuration: per-dimension short/long frequency factors plus an attention scaling factor
class LongRopeConfig : public ScalingConfig {
protected:
std::vector<float> short_factor_;
std::vector<float> long_factor_;
size_t original_max_position_embeddings_;
float factor_;

public:
LongRopeConfig(
std::vector<float> short_factor,
std::vector<float> long_factor,
size_t original_max_position_embeddings,
float factor = 1.0f)
: ScalingConfig(ScalingType::LONGROPE),
short_factor_(short_factor),
long_factor_(long_factor),
original_max_position_embeddings_(original_max_position_embeddings),
// Attention scaling factor (Long-RoPE convention): sqrt(1 + ln(factor) / ln(original_max_position_embeddings)); 1.0 means no scaling
factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
~LongRopeConfig() override = default;
size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
const std::vector<float> &short_factor() const { return short_factor_; }
const std::vector<float> &long_factor() const { return long_factor_; }
float factor() const { return factor_; }
};

/**
* @brief Construct a RoPE layer
*
@@ -26,13 +67,15 @@ class RoPE : public Module {
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
* @param scaling Optional RoPE scaling configuration; nullptr means no scaling (default: nullptr)
*/
RoPE(size_t head_dim,
size_t max_seq_len,
double theta = 10000.0,
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device());
const Device &device = Device(),
std::shared_ptr<ScalingConfig> scaling = nullptr);

/**
* @brief Forward pass: apply RoPE to a tensor
@@ -88,11 +131,12 @@ class RoPE : public Module {
private:
void initialize_cache();

size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<ScalingConfig> scaling_; // Optional RoPE scaling configuration (nullptr = no scaling)
};

} // namespace infinicore::nn
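
For reference, a minimal usage sketch of the new scaling parameter. This is not taken from the PR; the dimensions and factor values are illustrative placeholders, the include path and the infinicore::DataType / infinicore::Device names are assumptions, and the factor vectors are assumed to need one entry per cached dimension (head_dim / 2):

#include <cstddef>
#include <memory>
#include <vector>

#include "infinicore/nn/rope.hpp" // assumed include path

using infinicore::nn::RoPE;

void build_rope_with_longrope(const infinicore::Device &device) {
    const std::size_t head_dim = 128;
    const std::size_t max_seq_len = 131072;    // extended context length (illustrative)
    const std::size_t original_max_pos = 4096; // pre-training context length (illustrative)

    // Placeholder per-dimension factors, head_dim / 2 entries each.
    std::vector<float> short_factor(head_dim / 2, 1.0f);
    std::vector<float> long_factor(head_dim / 2, 4.0f);

    auto scaling = std::make_shared<RoPE::LongRopeConfig>(
        short_factor, long_factor, original_max_pos,
        static_cast<float>(max_seq_len) / static_cast<float>(original_max_pos));

    RoPE rope(head_dim, max_seq_len, /*theta=*/10000.0,
              RoPE::Algo::GPT_J, infinicore::DataType::F32, device, scaling);
}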
28 changes: 23 additions & 5 deletions src/infinicore/nn/rope.cc
@@ -16,12 +16,14 @@ RoPE::RoPE(size_t head_dim,
double theta,
Algo algo,
const DataType &dtype,
const Device &device)
const Device &device,
std::shared_ptr<ScalingConfig> scaling)
: head_dim_(head_dim),
max_seq_len_(max_seq_len),
theta_(theta),
algo_(algo),
dtype_(dtype) {
dtype_(dtype),
scaling_(scaling) {
if (head_dim % 2 != 0) {
throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
}
@@ -54,14 +56,30 @@ void RoPE::initialize_cache() {
for (size_t j = 0; j < cache_dim; j++) {
// GPT-J style inverse frequency: theta^(-2j/head_dim)
// Compute directly in float to avoid double->float casting
float inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
float inv_freq;
float table_factor = 1.0f;
if (scaling_ == nullptr) {
    inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
} else if (scaling_->type() == ScalingType::LONGROPE) {
    auto lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
    // Attention scaling factor, multiplied into the cached sin/cos values below.
    table_factor = lr->factor();
    // Per-dimension extension factor: short factors for positions within the
    // original context window, long factors beyond it.
    float ext_factor = (pos < lr->original_max_position_embeddings())
                           ? lr->short_factor()[j]
                           : lr->long_factor()[j];
    inv_freq = 1.0f / (ext_factor * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
} else {
    inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
}

// Compute angle: position * inverse_frequency
float angle = static_cast<float>(pos) * inv_freq;

// Compute sin and cos directly on float
sin_data[pos * cache_dim + j] = std::sin(angle);
cos_data[pos * cache_dim + j] = std::cos(angle);
sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
}
}

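
Note on the attention scaling factor: LongRopeConfig stores factor_ = sqrt(1 + ln(factor) / ln(original_max_position_embeddings)) (the Long-RoPE convention), and initialize_cache() multiplies every cached sin/cos entry by it via table_factor. As a worked example with the illustrative numbers from the sketch above, extending a 4096-token pre-training window to 131072 tokens gives factor = 32 and sqrt(1 + ln(32)/ln(4096)) = sqrt(1 + 5/12) ≈ 1.19, so the cached tables are scaled up by roughly 19%.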