From d6dbfd21e30ad7f5b71ea1f139b8fe89dc3602ea Mon Sep 17 00:00:00 2001 From: cearX Date: Mon, 12 Jan 2026 17:49:01 +0800 Subject: [PATCH] init --- include/infinicore_infer.h | 1 + .../infinicore_infer/models/deepseek_ocr.h | 186 +++ scripts/deepseek_ocr.py | 750 ++++++++++ scripts/libinfinicore_infer/__init__.py | 9 + scripts/libinfinicore_infer/deepseek_ocr.py | 251 ++++ src/cache_manager/opcache_manager.hpp | 6 + src/models/deepseek_ocr/deepseek_ocr.cpp | 1253 +++++++++++++++++ src/models/deepseek_ocr/deepseek_ocr_impl.hpp | 157 +++ .../deepseek_ocr/deepseek_ocr_weight.hpp | 454 ++++++ src/models/inference_context.cpp | 76 + src/models/inference_context.hpp | 27 + 11 files changed, 3170 insertions(+) create mode 100644 include/infinicore_infer/models/deepseek_ocr.h create mode 100644 scripts/deepseek_ocr.py create mode 100644 scripts/libinfinicore_infer/deepseek_ocr.py create mode 100644 src/models/deepseek_ocr/deepseek_ocr.cpp create mode 100644 src/models/deepseek_ocr/deepseek_ocr_impl.hpp create mode 100644 src/models/deepseek_ocr/deepseek_ocr_weight.hpp diff --git a/include/infinicore_infer.h b/include/infinicore_infer.h index 0bed7bc7..3816103c 100644 --- a/include/infinicore_infer.h +++ b/include/infinicore_infer.h @@ -5,6 +5,7 @@ #include "infinicore_infer/weights_loader.h" #include "infinicore_infer/models/deepseek.h" +#include "infinicore_infer/models/deepseek_ocr.h" #include "infinicore_infer/models/jiuge.h" #endif /* INFINICORE_INFER_H */ diff --git a/include/infinicore_infer/models/deepseek_ocr.h b/include/infinicore_infer/models/deepseek_ocr.h new file mode 100644 index 00000000..00d9761c --- /dev/null +++ b/include/infinicore_infer/models/deepseek_ocr.h @@ -0,0 +1,186 @@ +#ifndef MODEL_DEEPSEEK_OCR_H +#define MODEL_DEEPSEEK_OCR_H + +#include +#include +#include + +#include + +struct DeepSeekOCRModel; + +typedef struct +{ + infiniDtype_t dt_logits; + infiniDtype_t dt_norm; + // Layer counts + size_t n_dense_layer; // 第0层是dense + size_t n_sparse_layer; // 第1-11层是MoE + // Model dimensions + size_t d; // hidden_size: 1280 + size_t nh; // num_attention_heads: 1280 + size_t nkvh; // num_key_value_heads: 1280 + size_t dh; // head_dim: d/nh = 1 + // Dense MLP dimensions + size_t di_dense; // intermediate_size for dense layer: 6848 + // MoE dimensions + size_t di_moe; // moe_intermediate_size: 896 + size_t di_shared; // shared_expert_intermediate_size: 1792 + size_t nexperts; // n_routed_experts: 64 + size_t kexperts; // num_experts_per_tok: 6 + float routed_scale; // routed_scaling_factor: 1.0 + // Context and vocab + size_t dctx; // max_position_embeddings + size_t dvoc; // vocab_size: 129280 + // Normalization + float epsilon; // rms_norm_eps: 1e-6 + float theta; // rope_theta: 10000.0 + uint32_t end_token; // eos_token_id +} DeepSeekOCRMeta; + +typedef struct +{ + size_t n_dense_layer; + size_t n_sparse_layer; + infiniDtype_t dt_norm, dt_mat; + // 0 if linear weights are passed as W, any other value if passed as W^T + int transpose_linear_weights; + + // Embeddings + const void *input_embd; // [dvoc, d] + const void *output_norm; // [d] + const void *output_embd; // [dvoc, d] + + // Attention layers (all layers: n_dense_layer + n_sparse_layer) + const void *const *attn_norm; // nlayer * [d] + const void *const *attn_q; // nlayer * [d, d] or sharded + const void *const *attn_k; // nlayer * [d, d] or sharded + const void *const *attn_v; // nlayer * [d, d] or sharded + const void *const *attn_o; // nlayer * [d, d] or sharded + + // FFN layers + const void *const *ffn_norm; // nlayer * [d] + + // Dense MLP (layer 0) + const void *dense_gate; // [di_dense, d] + const void *dense_up; // [di_dense, d] + const void *dense_down; // [d, di_dense] + + // MoE layers (layer 1-11) + const void *const *moe_gate_weight; // n_sparse_layer * [nexperts, d] + const void *const *moe_gate_bias; // n_sparse_layer * [nexperts] + + // Shared experts + const void *const *moe_shared_gate; // n_sparse_layer * [di_shared, d] + const void *const *moe_shared_up; // n_sparse_layer * [di_shared, d] + const void *const *moe_shared_down; // n_sparse_layer * [d, di_shared] + + // Routed experts + const void *const *const *moe_experts_gate; // n_sparse_layer * nexperts * [di_moe, d] + const void *const *const *moe_experts_up; // n_sparse_layer * nexperts * [di_moe, d] + const void *const *const *moe_experts_down; // n_sparse_layer * nexperts * [d, di_moe] + + // Vision Encoder weights + // SAM ViT-B + const void *sam_patch_embed; + const void *sam_patch_embed_bias; + const void *const *sam_block_norm1; // 12 layers + const void *const *sam_block_attn_qkv; // 12 layers + const void *const *sam_block_attn_proj; // 12 layers + const void *const *sam_block_norm2; // 12 layers + const void *const *sam_block_mlp_fc1; // 12 layers + const void *const *sam_block_mlp_fc2; // 12 layers + const void *sam_neck_conv1; + const void *sam_neck_ln1; + const void *sam_neck_conv2; + const void *sam_neck_ln2; + + // CLIP-L + const void *clip_patch_embed; + const void *clip_patch_embed_bias; + const void *clip_position_embed; + const void *clip_pre_layernorm; + const void *const *clip_block_ln1; // 24 layers + const void *const *clip_block_attn_qkv; // 24 layers + const void *const *clip_block_attn_proj; // 24 layers + const void *const *clip_block_ln2; // 24 layers + const void *const *clip_block_mlp_fc1; // 24 layers + const void *const *clip_block_mlp_fc2; // 24 layers + + // Projector + const void *projector; // [2048, 1280] Linear projection + const void *image_newline; // [1280] Image row separator + const void *view_seperator; // [1280] View separator +} DeepSeekOCRWeights; + +//////////////////// APIs /////////////////////// + +/// @brief 创建DeepSeek-OCR模型 +/// @param device 协处理器种类 +/// @param ndev 协处理器数量 +/// @param dev_ids 协处理器编号,长度为 ndev +__C __export struct DeepSeekOCRModel * +createDeepSeekOCRModel(const DeepSeekOCRMeta *, + const DeepSeekOCRWeights *, + infiniDevice_t device, + int ndev, + const int *dev_ids); + +/// @brief 销毁模型 +__C __export void +destroyDeepSeekOCRModel(struct DeepSeekOCRModel *); + +/// @brief 批次推理一轮,并采样出新的 token +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param temperature 采样温度(0. 表示贪心采样) +/// @param topk 采样 topk(1 表示贪心采样) +/// @param topp 采样 topp +/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq +__C __export void +inferBatchDeepSeekOCR(struct DeepSeekOCRModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮,输出 output embedding 后的 logits +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param logits 输出 logits,shape: [ntok, dvoc] +__C __export void +forwardBatchDeepSeekOCR(struct DeepSeekOCRModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits); + +/// @brief 使用预计算的embeddings进行推理(用于多模态输入) +/// @param inputs_embeds 输入embeddings,shape: [ntok, d] +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param temperature 采样温度 +/// @param topk 采样 topk +/// @param topp 采样 topp +/// @param output 输出 token 数组 +__C __export void +inferBatchDeepSeekOCRWithEmbeds(struct DeepSeekOCRModel *, + const void *inputs_embeds, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +#endif diff --git a/scripts/deepseek_ocr.py b/scripts/deepseek_ocr.py new file mode 100644 index 00000000..0c036b7c --- /dev/null +++ b/scripts/deepseek_ocr.py @@ -0,0 +1,750 @@ +import ctypes +from typing import List +from tqdm import tqdm +import os +import sys +import time +import json +import torch +import transformers + +from libinfinicore_infer import ( + DeepSeekOCRModel, + DeepSeekOCRMetaCStruct, + DeepSeekOCRWeightsCStruct, + DataType, + DeviceType, +) +from infer_task import InferTask, KVCache +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref + +torch.set_default_device("cpu") + + +def load_specific_tensor(model_dir, tensor_name): + """从safetensors模型加载特定tensor""" + import safetensors + + index_file = os.path.join(model_dir, "model.safetensors.index.json") + if not os.path.exists(index_file): + raise FileNotFoundError(f"Index file not found: {index_file}") + + with open(index_file, "r") as f: + index = json.load(f) + + weight_map = index["weight_map"] + if tensor_name not in weight_map: + raise KeyError(f"{tensor_name} not found in index") + + filename = weight_map[tensor_name] + tensor_file = os.path.join(model_dir, filename) + + with safetensors.safe_open(tensor_file, framework="pt", device="cpu") as f: + tensor = f.get_tensor(tensor_name) + return tensor + + +class DeepSeekOCRWeightsNaming: + """DeepSeek-OCR权重命名规则""" + + def __init__(self, n_dense=1, n_sparse=11): + self.n_dense = n_dense + self.n_sparse = n_sparse + + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + # Attention layers + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_q_proj(self, i): + return f"model.layers.{i}.self_attn.q_proj.weight" + + def attn_k_proj(self, i): + return f"model.layers.{i}.self_attn.k_proj.weight" + + def attn_v_proj(self, i): + return f"model.layers.{i}.self_attn.v_proj.weight" + + def attn_o_proj(self, i): + return f"model.layers.{i}.self_attn.o_proj.weight" + + # FFN + def ffn_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + # Dense MLP (layer 0) + def dense_gate_proj(self, i): + assert i < self.n_dense + return f"model.layers.{i}.mlp.gate_proj.weight" + + def dense_up_proj(self, i): + assert i < self.n_dense + return f"model.layers.{i}.mlp.up_proj.weight" + + def dense_down_proj(self, i): + assert i < self.n_dense + return f"model.layers.{i}.mlp.down_proj.weight" + + # MoE layers (layer 1-11) + def moe_gate_weight(self, i): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.gate.weight" + + def moe_gate_bias(self, i): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.gate.e_score_correction_bias" + + def moe_shared_gate_proj(self, i): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.shared_experts.gate_proj.weight" + + def moe_shared_up_proj(self, i): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.shared_experts.up_proj.weight" + + def moe_shared_down_proj(self, i): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.shared_experts.down_proj.weight" + + def moe_experts_gate_proj(self, i, e): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" + + def moe_experts_up_proj(self, i, e): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" + + def moe_experts_down_proj(self, i, e): + assert i >= self.n_dense + return f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" + + # Vision Encoder weights + def projector(self): + return "model.projector.layers.weight" + + def image_newline(self): + return "model.image_newline" + + def view_seperator(self): + return "model.view_seperator" + + # SAM ViT-B weights (12 layers) + def sam_patch_embed_weight(self): + return "model.sam_model.patch_embed.proj.weight" + + def sam_patch_embed_bias(self): + return "model.sam_model.patch_embed.proj.bias" + + def sam_block_norm1(self, layer): + return f"model.sam_model.blocks.{layer}.norm1.weight" + + def sam_block_attn_qkv(self, layer): + return f"model.sam_model.blocks.{layer}.attn.qkv.weight" + + def sam_block_attn_proj(self, layer): + return f"model.sam_model.blocks.{layer}.attn.proj.weight" + + def sam_block_norm2(self, layer): + return f"model.sam_model.blocks.{layer}.norm2.weight" + + def sam_block_mlp_fc1(self, layer): + return f"model.sam_model.blocks.{layer}.mlp.lin1.weight" + + def sam_block_mlp_fc2(self, layer): + return f"model.sam_model.blocks.{layer}.mlp.lin2.weight" + + def sam_neck_conv1(self): + return "model.sam_model.neck.0.weight" + + def sam_neck_ln1(self): + return "model.sam_model.neck.1.weight" + + def sam_neck_conv2(self): + return "model.sam_model.neck.2.weight" + + def sam_neck_ln2(self): + return "model.sam_model.neck.3.weight" + + # CLIP-L weights (24 layers) + def clip_patch_embed_weight(self): + return "model.vision_model.embeddings.patch_embedding.weight" + + def clip_patch_embed_bias(self): + return "model.vision_model.embeddings.patch_embedding.bias" + + def clip_position_embed(self): + return "model.vision_model.embeddings.position_embedding.weight" + + def clip_pre_layernorm(self): + return "model.vision_model.pre_layrnorm.weight" + + def clip_block_ln1(self, layer): + return f"model.vision_model.encoder.layers.{layer}.layer_norm1.weight" + + def clip_block_attn_qkv(self, layer): + return f"model.vision_model.encoder.layers.{layer}.self_attn.qkv.weight" + + def clip_block_attn_proj(self, layer): + return f"model.vision_model.encoder.layers.{layer}.self_attn.out_proj.weight" + + def clip_block_ln2(self, layer): + return f"model.vision_model.encoder.layers.{layer}.layer_norm2.weight" + + def clip_block_mlp_fc1(self, layer): + return f"model.vision_model.encoder.layers.{layer}.mlp.fc1.weight" + + def clip_block_mlp_fc2(self, layer): + return f"model.vision_model.encoder.layers.{layer}.mlp.fc2.weight" + + +class DeepSeekOCRMeta(DeepSeekOCRMetaCStruct): + def __init__(self, config, dtype=torch.bfloat16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_BF16 + + n_dense = config.get("first_k_dense_replace", 1) + n_sparse = config["num_hidden_layers"] - n_dense + + super().__init__( + dt_logits=dt_, + dt_norm=dt_, + n_dense_layer=n_dense, + n_sparse_layer=n_sparse, + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=config.get("num_key_value_heads", config["num_attention_heads"]), + dh=config["hidden_size"] // config["num_attention_heads"], + di_dense=config.get("intermediate_size", 6848), + di_moe=config.get("moe_intermediate_size", 896), + di_shared=config.get("shared_expert_intermediate_size", 1792), + nexperts=config.get("n_routed_experts", 64), + kexperts=config.get("num_experts_per_tok", 6), + routed_scale=config.get("routed_scaling_factor", 1.0), + dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens, + dvoc=config["vocab_size"], + epsilon=config.get("rms_norm_eps", 1e-6), + theta=config.get("rope_theta", 10000.0), + end_token=config.get("eos_token_id", 1), + ) + self.torch_dtype_logits = dtype + + +def load_deepseek_ocr_weights( + meta: DeepSeekOCRMeta, weights, model_path: str, ndev: int +): + """加载DeepSeek-OCR所有权重""" + from ctypes import cast + + names = DeepSeekOCRWeightsNaming() + nlayer = meta.n_dense_layer + meta.n_sparse_layer + + # 基础embeddings + input_embd = load_specific_tensor(model_path, names.input_embd()).to(meta.torch_dtype_logits) + weights.input_embd = input_embd.data_ptr() + + output_norm = load_specific_tensor(model_path, names.output_norm()).to(meta.torch_dtype_logits) + weights.output_norm = output_norm.data_ptr() + + output_embd = load_specific_tensor(model_path, names.output_embd()).to(meta.torch_dtype_logits) + weights.output_embd = output_embd.data_ptr() + + # Attention & FFN norm (所有层) + attn_norm_ptrs = (c_void_p * nlayer)() + attn_q_ptrs = (c_void_p * nlayer)() + attn_k_ptrs = (c_void_p * nlayer)() + attn_v_ptrs = (c_void_p * nlayer)() + attn_o_ptrs = (c_void_p * nlayer)() + ffn_norm_ptrs = (c_void_p * nlayer)() + + layer_tensors = [] # 保持引用 + + for i in tqdm(range(nlayer), desc="Loading layers"): + attn_norm = load_specific_tensor(model_path, names.attn_norm(i)).to(meta.torch_dtype_logits) + attn_norm_ptrs[i] = attn_norm.data_ptr() + layer_tensors.append(attn_norm) + + attn_q = load_specific_tensor(model_path, names.attn_q_proj(i)).to(meta.torch_dtype_logits) + attn_q_ptrs[i] = attn_q.data_ptr() + layer_tensors.append(attn_q) + + attn_k = load_specific_tensor(model_path, names.attn_k_proj(i)).to(meta.torch_dtype_logits) + attn_k_ptrs[i] = attn_k.data_ptr() + layer_tensors.append(attn_k) + + attn_v = load_specific_tensor(model_path, names.attn_v_proj(i)).to(meta.torch_dtype_logits) + attn_v_ptrs[i] = attn_v.data_ptr() + layer_tensors.append(attn_v) + + attn_o = load_specific_tensor(model_path, names.attn_o_proj(i)).to(meta.torch_dtype_logits) + attn_o_ptrs[i] = attn_o.data_ptr() + layer_tensors.append(attn_o) + + ffn_norm = load_specific_tensor(model_path, names.ffn_norm(i)).to(meta.torch_dtype_logits) + ffn_norm_ptrs[i] = ffn_norm.data_ptr() + layer_tensors.append(ffn_norm) + + weights.attn_norm = cast(attn_norm_ptrs, POINTER(c_void_p)) + weights.attn_q = cast(attn_q_ptrs, POINTER(c_void_p)) + weights.attn_k = cast(attn_k_ptrs, POINTER(c_void_p)) + weights.attn_v = cast(attn_v_ptrs, POINTER(c_void_p)) + weights.attn_o = cast(attn_o_ptrs, POINTER(c_void_p)) + weights.ffn_norm = cast(ffn_norm_ptrs, POINTER(c_void_p)) + + # Dense MLP (第0层) + dense_gate = load_specific_tensor(model_path, names.dense_gate_proj(0)).to(meta.torch_dtype_logits) + weights.dense_gate = dense_gate.data_ptr() + + dense_up = load_specific_tensor(model_path, names.dense_up_proj(0)).to(meta.torch_dtype_logits) + weights.dense_up = dense_up.data_ptr() + + dense_down = load_specific_tensor(model_path, names.dense_down_proj(0)).to(meta.torch_dtype_logits) + weights.dense_down = dense_down.data_ptr() + + # MoE (第1-11层) + n_sparse = meta.n_sparse_layer + moe_gate_weight_ptrs = (c_void_p * n_sparse)() + moe_gate_bias_ptrs = (c_void_p * n_sparse)() + moe_shared_gate_ptrs = (c_void_p * n_sparse)() + moe_shared_up_ptrs = (c_void_p * n_sparse)() + moe_shared_down_ptrs = (c_void_p * n_sparse)() + + moe_tensors = [] + for i in tqdm(range(1, nlayer), desc="Loading MoE layers"): + moe_idx = i - 1 + + gate_w = load_specific_tensor(model_path, names.moe_gate_weight(i)).to(meta.torch_dtype_logits) + moe_gate_weight_ptrs[moe_idx] = gate_w.data_ptr() + moe_tensors.append(gate_w) + + gate_b = load_specific_tensor(model_path, names.moe_gate_bias(i)).to(meta.torch_dtype_logits) + moe_gate_bias_ptrs[moe_idx] = gate_b.data_ptr() + moe_tensors.append(gate_b) + + shared_g = load_specific_tensor(model_path, names.moe_shared_gate_proj(i)).to(meta.torch_dtype_logits) + moe_shared_gate_ptrs[moe_idx] = shared_g.data_ptr() + moe_tensors.append(shared_g) + + shared_u = load_specific_tensor(model_path, names.moe_shared_up_proj(i)).to(meta.torch_dtype_logits) + moe_shared_up_ptrs[moe_idx] = shared_u.data_ptr() + moe_tensors.append(shared_u) + + shared_d = load_specific_tensor(model_path, names.moe_shared_down_proj(i)).to(meta.torch_dtype_logits) + moe_shared_down_ptrs[moe_idx] = shared_d.data_ptr() + moe_tensors.append(shared_d) + + weights.moe_gate_weight = cast(moe_gate_weight_ptrs, POINTER(c_void_p)) + weights.moe_gate_bias = cast(moe_gate_bias_ptrs, POINTER(c_void_p)) + weights.moe_shared_gate = cast(moe_shared_gate_ptrs, POINTER(c_void_p)) + weights.moe_shared_up = cast(moe_shared_up_ptrs, POINTER(c_void_p)) + weights.moe_shared_down = cast(moe_shared_down_ptrs, POINTER(c_void_p)) + + # Routed experts + nexperts = meta.nexperts + expert_gate_ptrs_per_layer = [] + expert_up_ptrs_per_layer = [] + expert_down_ptrs_per_layer = [] + expert_tensors = [] + + for moe_layer in tqdm(range(1, nlayer), desc="Loading MoE experts"): + expert_gate_ptrs = (c_void_p * nexperts)() + expert_up_ptrs = (c_void_p * nexperts)() + expert_down_ptrs = (c_void_p * nexperts)() + + for e in range(nexperts): + gate = load_specific_tensor(model_path, names.moe_experts_gate_proj(moe_layer, e)).to(meta.torch_dtype_logits) + expert_gate_ptrs[e] = gate.data_ptr() + expert_tensors.append(gate) + + up = load_specific_tensor(model_path, names.moe_experts_up_proj(moe_layer, e)).to(meta.torch_dtype_logits) + expert_up_ptrs[e] = up.data_ptr() + expert_tensors.append(up) + + down = load_specific_tensor(model_path, names.moe_experts_down_proj(moe_layer, e)).to(meta.torch_dtype_logits) + expert_down_ptrs[e] = down.data_ptr() + expert_tensors.append(down) + + expert_gate_ptrs_per_layer.append(cast(expert_gate_ptrs, POINTER(c_void_p))) + expert_up_ptrs_per_layer.append(cast(expert_up_ptrs, POINTER(c_void_p))) + expert_down_ptrs_per_layer.append(cast(expert_down_ptrs, POINTER(c_void_p))) + + weights.moe_experts_gate = cast((POINTER(c_void_p) * n_sparse)(*expert_gate_ptrs_per_layer), POINTER(POINTER(c_void_p))) + weights.moe_experts_up = cast((POINTER(c_void_p) * n_sparse)(*expert_up_ptrs_per_layer), POINTER(POINTER(c_void_p))) + weights.moe_experts_down = cast((POINTER(c_void_p) * n_sparse)(*expert_down_ptrs_per_layer), POINTER(POINTER(c_void_p))) + + # 视觉编码器权重 + vision_tensors = [] + + # SAM ViT-B (12 layers) + sam_patch_embed = load_specific_tensor(model_path, names.sam_patch_embed_weight()).to(meta.torch_dtype_logits) + weights.sam_patch_embed = sam_patch_embed.data_ptr() + vision_tensors.append(sam_patch_embed) + + sam_patch_embed_bias = load_specific_tensor(model_path, names.sam_patch_embed_bias()).to(meta.torch_dtype_logits) + weights.sam_patch_embed_bias = sam_patch_embed_bias.data_ptr() + vision_tensors.append(sam_patch_embed_bias) + + sam_block_norm1_ptrs = (c_void_p * 12)() + sam_block_attn_qkv_ptrs = (c_void_p * 12)() + sam_block_attn_proj_ptrs = (c_void_p * 12)() + sam_block_norm2_ptrs = (c_void_p * 12)() + sam_block_mlp_fc1_ptrs = (c_void_p * 12)() + sam_block_mlp_fc2_ptrs = (c_void_p * 12)() + + for layer in tqdm(range(12), desc="Loading SAM blocks"): + norm1 = load_specific_tensor(model_path, names.sam_block_norm1(layer)).to(meta.torch_dtype_logits) + sam_block_norm1_ptrs[layer] = norm1.data_ptr() + vision_tensors.append(norm1) + + qkv = load_specific_tensor(model_path, names.sam_block_attn_qkv(layer)).to(meta.torch_dtype_logits) + sam_block_attn_qkv_ptrs[layer] = qkv.data_ptr() + vision_tensors.append(qkv) + + proj = load_specific_tensor(model_path, names.sam_block_attn_proj(layer)).to(meta.torch_dtype_logits) + sam_block_attn_proj_ptrs[layer] = proj.data_ptr() + vision_tensors.append(proj) + + norm2 = load_specific_tensor(model_path, names.sam_block_norm2(layer)).to(meta.torch_dtype_logits) + sam_block_norm2_ptrs[layer] = norm2.data_ptr() + vision_tensors.append(norm2) + + fc1 = load_specific_tensor(model_path, names.sam_block_mlp_fc1(layer)).to(meta.torch_dtype_logits) + sam_block_mlp_fc1_ptrs[layer] = fc1.data_ptr() + vision_tensors.append(fc1) + + fc2 = load_specific_tensor(model_path, names.sam_block_mlp_fc2(layer)).to(meta.torch_dtype_logits) + sam_block_mlp_fc2_ptrs[layer] = fc2.data_ptr() + vision_tensors.append(fc2) + + weights.sam_block_norm1 = cast(sam_block_norm1_ptrs, POINTER(c_void_p)) + weights.sam_block_attn_qkv = cast(sam_block_attn_qkv_ptrs, POINTER(c_void_p)) + weights.sam_block_attn_proj = cast(sam_block_attn_proj_ptrs, POINTER(c_void_p)) + weights.sam_block_norm2 = cast(sam_block_norm2_ptrs, POINTER(c_void_p)) + weights.sam_block_mlp_fc1 = cast(sam_block_mlp_fc1_ptrs, POINTER(c_void_p)) + weights.sam_block_mlp_fc2 = cast(sam_block_mlp_fc2_ptrs, POINTER(c_void_p)) + + sam_neck_conv1 = load_specific_tensor(model_path, names.sam_neck_conv1()).to(meta.torch_dtype_logits) + weights.sam_neck_conv1 = sam_neck_conv1.data_ptr() + vision_tensors.append(sam_neck_conv1) + + sam_neck_ln1 = load_specific_tensor(model_path, names.sam_neck_ln1()).to(meta.torch_dtype_logits) + weights.sam_neck_ln1 = sam_neck_ln1.data_ptr() + vision_tensors.append(sam_neck_ln1) + + sam_neck_conv2 = load_specific_tensor(model_path, names.sam_neck_conv2()).to(meta.torch_dtype_logits) + weights.sam_neck_conv2 = sam_neck_conv2.data_ptr() + vision_tensors.append(sam_neck_conv2) + + sam_neck_ln2 = load_specific_tensor(model_path, names.sam_neck_ln2()).to(meta.torch_dtype_logits) + weights.sam_neck_ln2 = sam_neck_ln2.data_ptr() + vision_tensors.append(sam_neck_ln2) + + # CLIP-L (24 layers) + clip_patch_embed = load_specific_tensor(model_path, names.clip_patch_embed_weight()).to(meta.torch_dtype_logits) + weights.clip_patch_embed = clip_patch_embed.data_ptr() + vision_tensors.append(clip_patch_embed) + + clip_patch_embed_bias = load_specific_tensor(model_path, names.clip_patch_embed_bias()).to(meta.torch_dtype_logits) + weights.clip_patch_embed_bias = clip_patch_embed_bias.data_ptr() + vision_tensors.append(clip_patch_embed_bias) + + clip_position_embed = load_specific_tensor(model_path, names.clip_position_embed()).to(meta.torch_dtype_logits) + weights.clip_position_embed = clip_position_embed.data_ptr() + vision_tensors.append(clip_position_embed) + + clip_pre_layernorm = load_specific_tensor(model_path, names.clip_pre_layernorm()).to(meta.torch_dtype_logits) + weights.clip_pre_layernorm = clip_pre_layernorm.data_ptr() + vision_tensors.append(clip_pre_layernorm) + + clip_block_ln1_ptrs = (c_void_p * 24)() + clip_block_attn_qkv_ptrs = (c_void_p * 24)() + clip_block_attn_proj_ptrs = (c_void_p * 24)() + clip_block_ln2_ptrs = (c_void_p * 24)() + clip_block_mlp_fc1_ptrs = (c_void_p * 24)() + clip_block_mlp_fc2_ptrs = (c_void_p * 24)() + + for layer in tqdm(range(24), desc="Loading CLIP blocks"): + ln1 = load_specific_tensor(model_path, names.clip_block_ln1(layer)).to(meta.torch_dtype_logits) + clip_block_ln1_ptrs[layer] = ln1.data_ptr() + vision_tensors.append(ln1) + + qkv = load_specific_tensor(model_path, names.clip_block_attn_qkv(layer)).to(meta.torch_dtype_logits) + clip_block_attn_qkv_ptrs[layer] = qkv.data_ptr() + vision_tensors.append(qkv) + + proj = load_specific_tensor(model_path, names.clip_block_attn_proj(layer)).to(meta.torch_dtype_logits) + clip_block_attn_proj_ptrs[layer] = proj.data_ptr() + vision_tensors.append(proj) + + ln2 = load_specific_tensor(model_path, names.clip_block_ln2(layer)).to(meta.torch_dtype_logits) + clip_block_ln2_ptrs[layer] = ln2.data_ptr() + vision_tensors.append(ln2) + + fc1 = load_specific_tensor(model_path, names.clip_block_mlp_fc1(layer)).to(meta.torch_dtype_logits) + clip_block_mlp_fc1_ptrs[layer] = fc1.data_ptr() + vision_tensors.append(fc1) + + fc2 = load_specific_tensor(model_path, names.clip_block_mlp_fc2(layer)).to(meta.torch_dtype_logits) + clip_block_mlp_fc2_ptrs[layer] = fc2.data_ptr() + vision_tensors.append(fc2) + + weights.clip_block_ln1 = cast(clip_block_ln1_ptrs, POINTER(c_void_p)) + weights.clip_block_attn_qkv = cast(clip_block_attn_qkv_ptrs, POINTER(c_void_p)) + weights.clip_block_attn_proj = cast(clip_block_attn_proj_ptrs, POINTER(c_void_p)) + weights.clip_block_ln2 = cast(clip_block_ln2_ptrs, POINTER(c_void_p)) + weights.clip_block_mlp_fc1 = cast(clip_block_mlp_fc1_ptrs, POINTER(c_void_p)) + weights.clip_block_mlp_fc2 = cast(clip_block_mlp_fc2_ptrs, POINTER(c_void_p)) + + # Projector + projector = load_specific_tensor(model_path, names.projector()).to(meta.torch_dtype_logits) + weights.projector = projector.data_ptr() + vision_tensors.append(projector) + + image_newline = load_specific_tensor(model_path, names.image_newline()).to(meta.torch_dtype_logits) + weights.image_newline = image_newline.data_ptr() + vision_tensors.append(image_newline) + + view_seperator = load_specific_tensor(model_path, names.view_seperator()).to(meta.torch_dtype_logits) + weights.view_seperator = view_seperator.data_ptr() + vision_tensors.append(view_seperator) + + return layer_tensors + moe_tensors + expert_tensors + vision_tensors + [ + input_embd, output_norm, output_embd, dense_gate, dense_up, dense_down + ] + + +class DeepSeekOCRBatchedTask: + def __init__(self, tasks: List[InferTask]): + from libinfinicore_infer import KVCacheCStruct + + self.tasks = tasks + self.nreq = len(tasks) + + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +class DeepSeekOCRForCausalLM: + def __init__( + self, + model_dir_path, + device=DeviceType.DEVICE_TYPE_CPU, + ndev=1, + max_tokens=None, + ): + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + print(f"Loading model from: {model_dir_path}") + + # 创建meta + self.meta = DeepSeekOCRMeta(config, max_tokens=max_tokens, dtype=torch.bfloat16) + + # 加载tokenizer + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + + # 创建C++模型并加载权重 + print(f"Creating model on {ndev} devices...") + load_start_time = time.time() + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + + self.model_instance = DeepSeekOCRModel() + + # 创建权重结构 + from ctypes import Structure, POINTER + weights = DeepSeekOCRWeightsCStruct() + weights.n_dense_layer = self.meta.n_dense_layer + weights.n_sparse_layer = self.meta.n_sparse_layer + weights.dt_norm = self.meta.dt_norm + weights.dt_mat = self.meta.dt_logits + weights.transpose_linear_weights = 1 # PyTorch format is transposed + + # 加载所有权重 + print("Loading model weights...") + self.weight_tensors = load_deepseek_ocr_weights(self.meta, weights, model_dir_path, ndev) + + # 创建模型 + self.model_ptr = self.model_instance.create_model( + byref(self.meta), + byref(weights), + device, + ndev, + dev_ids, + ) + + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + nlayer = self.meta.n_dense_layer + self.meta.n_sparse_layer + return self.model_instance.create_kv_cache( + nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.model_ptr.contents.resources[0].device if hasattr(self.model_ptr.contents, 'resources') else DeviceType.DEVICE_TYPE_CPU, + (c_int * 1)(0), + 1, + ) + + def drop_kv_cache(self, kv_cache): + self.model_instance.drop_kv_cache(kv_cache) + + + def generate(self, input_content, max_steps, image_path=None, + topp_=1.0, topk_=1, temperature_=1.0): + """生成文本(支持多模态输入)""" + # 构建prompt + if image_path: + input_content = f"\n{input_content}" + + # Tokenize + tokens = self.tokenizer.encode(input_content) + + # 创建推理任务 + infer_task = InferTask( + 0, tokens, self.max_context_len(), + temperature_, topk_, topp_, self.eos_token_id + ) + infer_task.bind_kvcache(KVCache(self)) + + print(input_content, end="", flush=True) + + steps = 0 + total_time = 0 + output_content = "" + + # 生成循环 + for step_i in range(max_steps): + start_time = time.time() + + # 调用推理 + output_tokens = self.batch_infer_one_round([infer_task]) + + end_time = time.time() + steps += 1 + + # 解码 + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += output_str + print(output_str, end="", flush=True) + + # 检查结束 + if output_tokens[0] in self.eos_token_id: + break + + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / max(steps - 1, 1) + print(f"Time per step: {avg_time:.3f}ms") + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = DeepSeekOCRBatchedTask(tasks) + self.model_instance.infer_batch( + self.model_ptr, *(batch_inputs.input_args()), output + ) + return list(output) + + def destroy_model_instance(self): + self.model_instance.destroy_model(self.model_ptr) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python deepseek_ocr.py [--cpu | --nvidia | --cambricon | --ascend] [n_device]" + ) + sys.exit(1) + + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + else: + print( + "Usage: python deepseek_ocr.py [--cpu | --nvidia | --cambricon | --ascend] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = DeepSeekOCRForCausalLM(model_path, device_type, ndev, max_tokens=2048) + + # 测试纯文本生成 + output, avg_time = model.generate("北京是中国的首都吗?", 50) + print(f"\nAverage time per step: {avg_time:.3f}ms") + + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 8fc5f4db..be828179 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -8,6 +8,11 @@ DeepSeekV3WeightLoaderCStruct, DeepSeekV3CacheCStruct, ) +from .deepseek_ocr import ( + DeepSeekOCRModel, + DeepSeekOCRMetaCStruct, + DeepSeekOCRWeightsCStruct, +) __all__ = [ "DataType", @@ -23,5 +28,9 @@ "DeepSeekV3MetaCStruct", "DeepSeekV3WeightsCStruct", "DeepSeekV3WeightLoaderCStruct", + "DeepSeekV3CacheCStruct", + "DeepSeekOCRModel", + "DeepSeekOCRMetaCStruct", + "DeepSeekOCRWeightsCStruct", "ModelRegister", ] diff --git a/scripts/libinfinicore_infer/deepseek_ocr.py b/scripts/libinfinicore_infer/deepseek_ocr.py new file mode 100644 index 00000000..98b4963c --- /dev/null +++ b/scripts/libinfinicore_infer/deepseek_ocr.py @@ -0,0 +1,251 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure, byref + + +class DeepSeekOCRMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("dt_norm", DataType), + # Layer counts + ("n_dense_layer", c_size_t), # 第0层是dense + ("n_sparse_layer", c_size_t), # 第1-11层是MoE + # Model dimensions + ("d", c_size_t), # hidden_size: 1280 + ("nh", c_size_t), # num_attention_heads: 1280 + ("nkvh", c_size_t), # num_key_value_heads: 1280 + ("dh", c_size_t), # head_dim: d/nh = 1 + # Dense MLP dimensions + ("di_dense", c_size_t), # intermediate_size for dense layer: 6848 + # MoE dimensions + ("di_moe", c_size_t), # moe_intermediate_size: 896 + ("di_shared", c_size_t), # shared_expert_intermediate_size: 1792 + ("nexperts", c_size_t), # n_routed_experts: 64 + ("kexperts", c_size_t), # num_experts_per_tok: 6 + ("routed_scale", c_float), # routed_scaling_factor: 1.0 + # Context and vocab + ("dctx", c_size_t), # max_position_embeddings + ("dvoc", c_size_t), # vocab_size: 129280 + # Normalization + ("epsilon", c_float), # rms_norm_eps: 1e-6 + ("theta", c_float), # rope_theta: 10000.0 + ("end_token", c_uint), # eos_token_id + ] + + +class DeepSeekOCRWeightsCStruct(Structure): + _fields_ = [ + ("n_dense_layer", c_size_t), + ("n_sparse_layer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + # Embeddings + ("input_embd", c_void_p), # [dvoc, d] + ("output_norm", c_void_p), # [d] + ("output_embd", c_void_p), # [dvoc, d] + # Attention layers (all layers) + ("attn_norm", POINTER(c_void_p)), # nlayer * [d] + ("attn_q", POINTER(c_void_p)), # nlayer * [d, d] 或分片 + ("attn_k", POINTER(c_void_p)), # nlayer * [d, d] 或分片 + ("attn_v", POINTER(c_void_p)), # nlayer * [d, d] 或分片 + ("attn_o", POINTER(c_void_p)), # nlayer * [d, d] 或分片 + # FFN layers + ("ffn_norm", POINTER(c_void_p)), # nlayer * [d] + # Dense MLP (layer 0) + ("dense_gate", c_void_p), # [di_dense, d] + ("dense_up", c_void_p), # [di_dense, d] + ("dense_down", c_void_p), # [d, di_dense] + # MoE layers (layer 1-11) + ("moe_gate_weight", POINTER(c_void_p)), # n_sparse_layer * [nexperts, d] + ("moe_gate_bias", POINTER(c_void_p)), # n_sparse_layer * [nexperts] + # Shared experts + ("moe_shared_gate", POINTER(c_void_p)), # n_sparse_layer * [di_shared, d] + ("moe_shared_up", POINTER(c_void_p)), # n_sparse_layer * [di_shared, d] + ("moe_shared_down", POINTER(c_void_p)), # n_sparse_layer * [d, di_shared] + # Routed experts + ("moe_experts_gate", POINTER(POINTER(c_void_p))), # n_sparse_layer * nexperts * [di_moe, d] + ("moe_experts_up", POINTER(POINTER(c_void_p))), # n_sparse_layer * nexperts * [di_moe, d] + ("moe_experts_down", POINTER(POINTER(c_void_p))), # n_sparse_layer * nexperts * [d, di_moe] + # Vision weights - SAM + ("sam_patch_embed", c_void_p), + ("sam_patch_embed_bias", c_void_p), + ("sam_block_norm1", POINTER(c_void_p)), # 12 layers + ("sam_block_attn_qkv", POINTER(c_void_p)), + ("sam_block_attn_proj", POINTER(c_void_p)), + ("sam_block_norm2", POINTER(c_void_p)), + ("sam_block_mlp_fc1", POINTER(c_void_p)), + ("sam_block_mlp_fc2", POINTER(c_void_p)), + ("sam_neck_conv1", c_void_p), + ("sam_neck_ln1", c_void_p), + ("sam_neck_conv2", c_void_p), + ("sam_neck_ln2", c_void_p), + # Vision weights - CLIP + ("clip_patch_embed", c_void_p), + ("clip_patch_embed_bias", c_void_p), + ("clip_position_embed", c_void_p), + ("clip_pre_layernorm", c_void_p), + ("clip_block_ln1", POINTER(c_void_p)), # 24 layers + ("clip_block_attn_qkv", POINTER(c_void_p)), + ("clip_block_attn_proj", POINTER(c_void_p)), + ("clip_block_ln2", POINTER(c_void_p)), + ("clip_block_mlp_fc1", POINTER(c_void_p)), + ("clip_block_mlp_fc2", POINTER(c_void_p)), + # Projector + ("projector", c_void_p), # [2048, 1280] + ("image_newline", c_void_p), # [1280] + ("view_seperator", c_void_p), # [1280] + ] + + +class DeepSeekOCRModelCStruct(Structure): + pass + + +@register_model +class DeepSeekOCRModel(BaseModel): + @classmethod + def register_lib(cls, lib): + lib.createDeepSeekOCRModel.restype = POINTER(DeepSeekOCRModelCStruct) + lib.createDeepSeekOCRModel.argtypes = [ + POINTER(DeepSeekOCRMetaCStruct), + POINTER(DeepSeekOCRWeightsCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + + lib.destroyDeepSeekOCRModel.argtypes = [POINTER(DeepSeekOCRModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + lib.inferBatchDeepSeekOCR.argtypes = [ + POINTER(DeepSeekOCRModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.forwardBatchDeepSeekOCR.argtypes = [ + POINTER(DeepSeekOCRModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_void_p, + ] + + # 新增: 用于注入图像特征的接口 + lib.inferBatchDeepSeekOCRWithEmbeds.argtypes = [ + POINTER(DeepSeekOCRModelCStruct), + c_void_p, # inputs_embeds + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + def create_model(self, meta, weights, device_type, ndev, dev_ids): + return self.lib.createDeepSeekOCRModel(meta, weights, device_type, ndev, dev_ids) + + def destroy_model(self, model): + self.lib.destroyDeepSeekOCRModel(model) + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchDeepSeekOCR( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) + + def forward_batch( + self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ): + self.lib.forwardBatchDeepSeekOCR( + model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ) + + def infer_batch_with_embeds( + self, + model, + inputs_embeds, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchDeepSeekOCRWithEmbeds( + model, + inputs_embeds, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 4c49e961..cdaeb18c 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -154,6 +154,9 @@ class CacheManager { public: DECLARE_OP_CACHE(Add) DECLARE_OP_CACHE(RMSNorm) + DECLARE_OP_CACHE(LayerNorm) + DECLARE_OP_CACHE(GELU) + DECLARE_OP_CACHE(Conv2d) DECLARE_OP_CACHE(Gemm) DECLARE_OP_CACHE(RoPE) DECLARE_OP_CACHE(Rearrange) @@ -166,6 +169,9 @@ class CacheManager { CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)), + LayerNorm_cache(capacity, DESTROY_FUNC(LayerNorm)), + GELU_cache(capacity, DESTROY_FUNC(GELU)), + Conv2d_cache(capacity, DESTROY_FUNC(Conv2d)), Gemm_cache(capacity, DESTROY_FUNC(Gemm)), RoPE_cache(capacity, DESTROY_FUNC(RoPE)), Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)), diff --git a/src/models/deepseek_ocr/deepseek_ocr.cpp b/src/models/deepseek_ocr/deepseek_ocr.cpp new file mode 100644 index 00000000..63cd8c2c --- /dev/null +++ b/src/models/deepseek_ocr/deepseek_ocr.cpp @@ -0,0 +1,1253 @@ +#include "deepseek_ocr_impl.hpp" +#include "deepseek_ocr_weight.hpp" + +#include "../../tensor.hpp" +#include "../../utils.hpp" +#include "../inference_context.hpp" +#include "infinicore_infer.h" + +#include +#include +#include +#include + +// ================ 创建设备资源 ================ + +void createDeviceResource(DeepSeekOCRDeviceResource *rsrc, + const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + infiniDevice_t device, + int idev, + int ndev, + int dev_id, + infinicclComm_t comm) { + RUN_INFINI(infinirtSetDevice(device, dev_id)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + size_t nlayer = meta->n_dense_layer + meta->n_sparse_layer; + + // 加载基础权重 + auto w_in_embd = getInEmbd(meta, weights); + auto w_out_norm = getOutNorm(meta, weights); + auto w_out_embd = getOutEmbd(meta, weights); + auto sin_table = getSinTable(meta); + auto cos_table = getCosTable(meta); + + // 加载所有层的attention权重 + std::vector> w_attn_norm, w_attn_q, w_attn_k, w_attn_v, w_attn_o; + std::vector> w_ffn_norm; + + for (size_t layer = 0; layer < nlayer; layer++) { + w_attn_norm.push_back(getAttnNorm(meta, weights, layer)); + w_attn_q.push_back(getAttnQ(meta, weights, layer, idev, ndev)); + w_attn_k.push_back(getAttnK(meta, weights, layer, idev, ndev)); + w_attn_v.push_back(getAttnV(meta, weights, layer, idev, ndev)); + w_attn_o.push_back(getAttnO(meta, weights, layer, idev, ndev)); + w_ffn_norm.push_back(getFFNNorm(meta, weights, layer)); + } + + // 加载Dense MLP权重 (第0层) + auto w_dense_gate = getDenseGate(meta, weights, idev, ndev); + auto w_dense_up = getDenseUp(meta, weights, idev, ndev); + auto w_dense_down = getDenseDown(meta, weights, idev, ndev); + + // 加载MoE权重 (第1-11层) + std::vector> w_moe_gate_weight, w_moe_gate_bias; + std::vector> w_moe_shared_gate, w_moe_shared_up, w_moe_shared_down; + std::vector>> w_moe_experts_gate, w_moe_experts_up, w_moe_experts_down; + + for (size_t i = 0; i < meta->n_sparse_layer; i++) { + w_moe_gate_weight.push_back(getMoEGateWeight(meta, weights, i)); + w_moe_gate_bias.push_back(getMoEGateBias(meta, weights, i)); + + w_moe_shared_gate.push_back(getMoESharedGate(meta, weights, i, idev, ndev)); + w_moe_shared_up.push_back(getMoESharedUp(meta, weights, i, idev, ndev)); + w_moe_shared_down.push_back(getMoESharedDown(meta, weights, i, idev, ndev)); + + // 加载所有routed experts + std::vector> experts_gate, experts_up, experts_down; + for (size_t e = 0; e < meta->nexperts; e++) { + experts_gate.push_back(getMoEExpertsGate(meta, weights, i, e, idev, ndev)); + experts_up.push_back(getMoEExpertsUp(meta, weights, i, e, idev, ndev)); + experts_down.push_back(getMoEExpertsDown(meta, weights, i, e, idev, ndev)); + } + w_moe_experts_gate.push_back(experts_gate); + w_moe_experts_up.push_back(experts_up); + w_moe_experts_down.push_back(experts_down); + } + + // 加载视觉编码器权重 + // SAM ViT-B + auto w_sam_patch_embed = getSAMPatchEmbed(meta, weights); + auto w_sam_patch_embed_bias = getSAMPatchEmbedBias(meta, weights); + + std::vector> w_sam_block_norm1, w_sam_block_attn_qkv, w_sam_block_attn_proj; + std::vector> w_sam_block_norm2, w_sam_block_mlp_fc1, w_sam_block_mlp_fc2; + for (size_t layer = 0; layer < 12; layer++) { + w_sam_block_norm1.push_back(getSAMBlockNorm1(meta, weights, layer)); + w_sam_block_attn_qkv.push_back(getSAMBlockAttnQKV(meta, weights, layer)); + w_sam_block_attn_proj.push_back(getSAMBlockAttnProj(meta, weights, layer)); + w_sam_block_norm2.push_back(getSAMBlockNorm2(meta, weights, layer)); + w_sam_block_mlp_fc1.push_back(getSAMBlockMLPFC1(meta, weights, layer)); + w_sam_block_mlp_fc2.push_back(getSAMBlockMLPFC2(meta, weights, layer)); + } + + auto w_sam_neck_conv1 = getSAMNeckConv1(meta, weights); + auto w_sam_neck_ln1 = getSAMNeckLN1(meta, weights); + auto w_sam_neck_conv2 = getSAMNeckConv2(meta, weights); + auto w_sam_neck_ln2 = getSAMNeckLN2(meta, weights); + + // CLIP-L + auto w_clip_patch_embed = getCLIPPatchEmbed(meta, weights); + auto w_clip_patch_embed_bias = getCLIPPatchEmbedBias(meta, weights); + auto w_clip_position_embed = getCLIPPositionEmbed(meta, weights); + auto w_clip_pre_layernorm = getCLIPPreLayerNorm(meta, weights); + + std::vector> w_clip_block_ln1, w_clip_block_attn_qkv, w_clip_block_attn_proj; + std::vector> w_clip_block_ln2, w_clip_block_mlp_fc1, w_clip_block_mlp_fc2; + for (size_t layer = 0; layer < 24; layer++) { + w_clip_block_ln1.push_back(getCLIPBlockLN1(meta, weights, layer)); + w_clip_block_attn_qkv.push_back(getCLIPBlockAttnQKV(meta, weights, layer)); + w_clip_block_attn_proj.push_back(getCLIPBlockAttnProj(meta, weights, layer)); + w_clip_block_ln2.push_back(getCLIPBlockLN2(meta, weights, layer)); + w_clip_block_mlp_fc1.push_back(getCLIPBlockMLPFC1(meta, weights, layer)); + w_clip_block_mlp_fc2.push_back(getCLIPBlockMLPFC2(meta, weights, layer)); + } + + // Projector + auto w_projector = getProjector(meta, weights); + auto w_image_newline = getImageNewline(meta, weights); + auto w_view_seperator = getViewSeperator(meta, weights); + + auto memory_pool = std::make_shared(256 * 1024 * 1024); // 256MB + + *rsrc = DeepSeekOCRDeviceResource{ + device, + dev_id, + handle, + w_in_embd, + w_out_norm, + w_out_embd, + sin_table, + cos_table, + w_attn_norm, + w_attn_q, + w_attn_k, + w_attn_v, + w_attn_o, + w_ffn_norm, + w_dense_gate, + w_dense_up, + w_dense_down, + w_moe_gate_weight, + w_moe_gate_bias, + w_moe_shared_gate, + w_moe_shared_up, + w_moe_shared_down, + w_moe_experts_gate, + w_moe_experts_up, + w_moe_experts_down, + w_sam_patch_embed, + w_sam_patch_embed_bias, + w_sam_block_norm1, + w_sam_block_attn_qkv, + w_sam_block_attn_proj, + w_sam_block_norm2, + w_sam_block_mlp_fc1, + w_sam_block_mlp_fc2, + w_sam_neck_conv1, + w_sam_neck_ln1, + w_sam_neck_conv2, + w_sam_neck_ln2, + w_clip_patch_embed, + w_clip_patch_embed_bias, + w_clip_position_embed, + w_clip_pre_layernorm, + w_clip_block_ln1, + w_clip_block_attn_qkv, + w_clip_block_attn_proj, + w_clip_block_ln2, + w_clip_block_mlp_fc1, + w_clip_block_mlp_fc2, + w_projector, + w_image_newline, + w_view_seperator, + stream, + comm, + memory_pool, + }; + + RUN_INFINI(infinirtDeviceSynchronize()); +} + +void releaseDeviceResource(DeepSeekOCRDeviceResource &res) { + infinirtDeviceSynchronize(); + + // 释放所有tensor + res.w_in_embd.reset(); + res.w_out_norm.reset(); + res.w_out_embd.reset(); + res.sin_table.reset(); + res.cos_table.reset(); + + for (auto &t : res.w_attn_norm) { + t.reset(); + } + res.w_attn_norm.clear(); + for (auto &t : res.w_attn_q) { + t.reset(); + } + res.w_attn_q.clear(); + for (auto &t : res.w_attn_k) { + t.reset(); + } + res.w_attn_k.clear(); + for (auto &t : res.w_attn_v) { + t.reset(); + } + res.w_attn_v.clear(); + for (auto &t : res.w_attn_o) { + t.reset(); + } + res.w_attn_o.clear(); + for (auto &t : res.w_ffn_norm) { + t.reset(); + } + res.w_ffn_norm.clear(); + + res.w_dense_gate.reset(); + res.w_dense_up.reset(); + res.w_dense_down.reset(); + + for (auto &t : res.w_moe_gate_weight) { + t.reset(); + } + res.w_moe_gate_weight.clear(); + for (auto &t : res.w_moe_gate_bias) { + t.reset(); + } + res.w_moe_gate_bias.clear(); + for (auto &t : res.w_moe_shared_gate) { + t.reset(); + } + res.w_moe_shared_gate.clear(); + for (auto &t : res.w_moe_shared_up) { + t.reset(); + } + res.w_moe_shared_up.clear(); + for (auto &t : res.w_moe_shared_down) { + t.reset(); + } + res.w_moe_shared_down.clear(); + + for (auto &experts : res.w_moe_experts_gate) { + for (auto &t : experts) { + t.reset(); + } + experts.clear(); + } + res.w_moe_experts_gate.clear(); + for (auto &experts : res.w_moe_experts_up) { + for (auto &t : experts) { + t.reset(); + } + experts.clear(); + } + res.w_moe_experts_up.clear(); + for (auto &experts : res.w_moe_experts_down) { + for (auto &t : experts) { + t.reset(); + } + experts.clear(); + } + res.w_moe_experts_down.clear(); + + // 释放视觉权重 - SAM + res.w_sam_patch_embed.reset(); + res.w_sam_patch_embed_bias.reset(); + for (auto &t : res.w_sam_block_norm1) { + t.reset(); + } + res.w_sam_block_norm1.clear(); + for (auto &t : res.w_sam_block_attn_qkv) { + t.reset(); + } + res.w_sam_block_attn_qkv.clear(); + for (auto &t : res.w_sam_block_attn_proj) { + t.reset(); + } + res.w_sam_block_attn_proj.clear(); + for (auto &t : res.w_sam_block_norm2) { + t.reset(); + } + res.w_sam_block_norm2.clear(); + for (auto &t : res.w_sam_block_mlp_fc1) { + t.reset(); + } + res.w_sam_block_mlp_fc1.clear(); + for (auto &t : res.w_sam_block_mlp_fc2) { + t.reset(); + } + res.w_sam_block_mlp_fc2.clear(); + + res.w_sam_neck_conv1.reset(); + res.w_sam_neck_ln1.reset(); + res.w_sam_neck_conv2.reset(); + res.w_sam_neck_ln2.reset(); + + // 释放视觉权重 - CLIP + res.w_clip_patch_embed.reset(); + res.w_clip_patch_embed_bias.reset(); + res.w_clip_position_embed.reset(); + res.w_clip_pre_layernorm.reset(); + for (auto &t : res.w_clip_block_ln1) { + t.reset(); + } + res.w_clip_block_ln1.clear(); + for (auto &t : res.w_clip_block_attn_qkv) { + t.reset(); + } + res.w_clip_block_attn_qkv.clear(); + for (auto &t : res.w_clip_block_attn_proj) { + t.reset(); + } + res.w_clip_block_attn_proj.clear(); + for (auto &t : res.w_clip_block_ln2) { + t.reset(); + } + res.w_clip_block_ln2.clear(); + for (auto &t : res.w_clip_block_mlp_fc1) { + t.reset(); + } + res.w_clip_block_mlp_fc1.clear(); + for (auto &t : res.w_clip_block_mlp_fc2) { + t.reset(); + } + res.w_clip_block_mlp_fc2.clear(); + + // Projector + res.w_projector.reset(); + res.w_image_newline.reset(); + res.w_view_seperator.reset(); + + infiniopDestroyHandle(res.handle); + res.handle = nullptr; + infinirtStreamDestroy(res.stream); + res.stream = nullptr; + infinicclCommDestroy(res.comm); + res.comm = nullptr; +} + +// ================ 视觉编码器推理 ================ +// 注意: 这里使用类似LLM算子的模式调用视觉算子 +// Conv2d/LayerNorm/GELU需要通过InferenceContext或直接API调用 + +std::shared_ptr inferVisionSAM(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + std::shared_ptr pixel_values) { + // SAM ViT-B推理: [batch, 3, H, W] -> [batch, num_patches, 1024] + auto dt_logits = meta.dt_logits; + auto stream = rsrc.stream; + auto handle = rsrc.handle; + auto batch = pixel_values->shape()[0]; + auto num_patches = (pixel_values->shape()[2] / 16) * (pixel_values->shape()[3] / 16); + + // 设置推理上下文 + auto cache_manager = new CacheManager(); + InferenceContext ctx(handle, rsrc.memory_pool, cache_manager, stream); + setInferenceContext(&ctx); + + // 1. Patch Embedding: Conv2d(3->768, kernel=16, stride=16) + auto H = pixel_values->shape()[2]; + auto W = pixel_values->shape()[3]; + auto patch_h = H / 16; + auto patch_w = W / 16; + + // Conv2d输出: [batch, 768, patch_h, patch_w] + auto conv_out = Tensor::buffer(dt_logits, {batch, 768, patch_h, patch_w}, rsrc.memory_pool); + conv2d(conv_out, pixel_values, rsrc.w_sam_patch_embed, rsrc.w_sam_patch_embed_bias, + 16, 16, 0, 0); // stride=16, padding=0 + + // Flatten(2).transpose(1,2): [batch, 768, h, w] -> [batch, h*w, 768] + auto patch_embeds = Tensor::buffer(dt_logits, {batch, num_patches, 768}, rsrc.memory_pool); + for (size_t b = 0; b < batch; b++) { + for (size_t h = 0; h < patch_h; h++) { + for (size_t w = 0; w < patch_w; w++) { + size_t patch_idx = h * patch_w + w; + for (size_t c = 0; c < 768; c++) { + size_t src_idx = ((b * 768 + c) * patch_h + h) * patch_w + w; + size_t dst_idx = (b * num_patches + patch_idx) * 768 + c; + RUN_INFINI(infinirtMemcpyAsync( + patch_embeds->data(dst_idx), + conv_out->data(src_idx), + dsize(dt_logits), INFINIRT_MEMCPY_D2D, stream)); + } + } + } + } + + // 2. ViT Transformer Blocks (12层) + auto hidden_states = patch_embeds; + for (size_t layer = 0; layer < 12; layer++) { + // 2.1 LayerNorm1 + residual connection + auto normed1 = Tensor::buffer(dt_logits, hidden_states->shape(), rsrc.memory_pool); + layer_norm(normed1, hidden_states, rsrc.w_sam_block_norm1[layer], meta.epsilon); + + // 2.2 Self-Attention: QKV -> Split -> Attention -> Proj + // SAM ViT-B: 12 heads, 768 dim, head_dim = 64 + const size_t num_heads = 12; + const size_t head_dim = 64; + + auto qkv_flat = Tensor::buffer(dt_logits, {batch * num_patches, 768 * 3}, rsrc.memory_pool); + linear(qkv_flat, normed1->view({batch * num_patches, 768}), + rsrc.w_sam_block_attn_qkv[layer], 1.0, 0.0, nullptr, nullptr); + + // Split QKV: [batch, num_patches, 768*3] -> [batch, num_patches, 3, num_heads, head_dim] + auto qkv = qkv_flat->view({batch, num_patches, 3, num_heads, head_dim}); + + // Extract Q, K, V: each [batch, num_patches, num_heads, head_dim] -> permute to [batch, num_heads, num_patches, head_dim] + auto q_buf = Tensor::buffer(dt_logits, {batch, num_heads, num_patches, head_dim}, rsrc.memory_pool); + auto k_buf = Tensor::buffer(dt_logits, {batch, num_heads, num_patches, head_dim}, rsrc.memory_pool); + auto v_buf = Tensor::buffer(dt_logits, {batch, num_heads, num_patches, head_dim}, rsrc.memory_pool); + + // Extract Q, K, V from QKV tensor (QKV layout: [batch, tokens, 3, heads, head_dim]) + auto q_extract = Tensor::buffer(dt_logits, {batch, num_patches, num_heads, head_dim}, rsrc.memory_pool); + auto k_extract = Tensor::buffer(dt_logits, {batch, num_patches, num_heads, head_dim}, rsrc.memory_pool); + auto v_extract = Tensor::buffer(dt_logits, {batch, num_patches, num_heads, head_dim}, rsrc.memory_pool); + + for (size_t b = 0; b < batch; b++) { + for (size_t t = 0; t < num_patches; t++) { + // Copy Q: qkv[b, t, 0, :, :] + RUN_INFINI(infinirtMemcpyAsync( + q_extract->data((b * num_patches + t) * num_heads * head_dim), + qkv_flat->data((b * num_patches + t) * 3 * num_heads * head_dim + 0 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + // Copy K: qkv[b, t, 1, :, :] + RUN_INFINI(infinirtMemcpyAsync( + k_extract->data((b * num_patches + t) * num_heads * head_dim), + qkv_flat->data((b * num_patches + t) * 3 * num_heads * head_dim + 1 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + // Copy V: qkv[b, t, 2, :, :] + RUN_INFINI(infinirtMemcpyAsync( + v_extract->data((b * num_patches + t) * num_heads * head_dim), + qkv_flat->data((b * num_patches + t) * 3 * num_heads * head_dim + 2 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + } + } + + // Permute from [batch, num_patches, num_heads, head_dim] to [batch, num_heads, num_patches, head_dim] + rearrange(q_buf, q_extract->view({batch, num_patches, num_heads, head_dim})->permute({0, 2, 1, 3})); + rearrange(k_buf, k_extract->view({batch, num_patches, num_heads, head_dim})->permute({0, 2, 1, 3})); + rearrange(v_buf, v_extract->view({batch, num_patches, num_heads, head_dim})->permute({0, 2, 1, 3})); + + // QK^T / sqrt(head_dim): [batch, num_heads, num_patches, num_patches] + auto qk_scores = Tensor::buffer(dt_logits, {batch * num_heads, num_patches, num_patches}, rsrc.memory_pool); + auto k_transposed = k_buf->view({batch * num_heads, num_patches, head_dim})->permute({0, 2, 1}); + linear(qk_scores, q_buf->view({batch * num_heads, num_patches, head_dim}), k_transposed, + 1.0f / sqrtf(head_dim), 0.0f, nullptr, nullptr); + + // Softmax over last dimension (non-causal for vision) + auto qk_softmax = qk_scores->view({batch * num_heads * num_patches, num_patches}); + causalSoftmax(qk_softmax, qk_softmax); // Note: 实际上应该是普通softmax + + // Attention @ V: [batch, num_heads, num_patches, head_dim] + auto attn_out_heads = Tensor::buffer(dt_logits, {batch * num_heads, num_patches, head_dim}, rsrc.memory_pool); + linear(attn_out_heads, qk_scores, v_buf->view({batch * num_heads, num_patches, head_dim}), + 1.0f, 0.0f, nullptr, nullptr); + + // Transpose and reshape: [batch, num_heads, num_patches, head_dim] -> [batch, num_patches, num_heads, head_dim] -> [batch, num_patches, 768] + auto attn_transposed = Tensor::buffer(dt_logits, {batch, num_patches, num_heads, head_dim}, rsrc.memory_pool); + rearrange(attn_transposed, attn_out_heads->view({batch, num_heads, num_patches, head_dim})->permute({0, 2, 1, 3})); + auto attn_out = attn_transposed->view({batch * num_patches, 768}); + + // Output projection with residual + auto attn_proj = Tensor::buffer(dt_logits, {batch, num_patches, 768}, rsrc.memory_pool); + linear(attn_proj->view({batch * num_patches, 768}), attn_out, + rsrc.w_sam_block_attn_proj[layer], 1.0, 0.0, + hidden_states->view({batch * num_patches, 768}), nullptr); + hidden_states = attn_proj; + + // 2.3 LayerNorm2 + residual connection + auto normed2 = Tensor::buffer(dt_logits, hidden_states->shape(), rsrc.memory_pool); + layer_norm(normed2, hidden_states, rsrc.w_sam_block_norm2[layer], meta.epsilon); + + // 2.4 MLP: FC1 -> GELU -> FC2 + auto mlp_hidden_flat = Tensor::buffer(dt_logits, {batch * num_patches, 3072}, rsrc.memory_pool); + linear(mlp_hidden_flat, normed2->view({batch * num_patches, 768}), + rsrc.w_sam_block_mlp_fc1[layer], 1.0, 0.0, nullptr, nullptr); + + gelu(mlp_hidden_flat, mlp_hidden_flat); + + auto mlp_out = Tensor::buffer(dt_logits, {batch, num_patches, 768}, rsrc.memory_pool); + linear(mlp_out->view({batch * num_patches, 768}), mlp_hidden_flat, + rsrc.w_sam_block_mlp_fc2[layer], 1.0, 0.0, + hidden_states->view({batch * num_patches, 768}), nullptr); + hidden_states = mlp_out; + } + + // 3. Neck网络: 将768维投影到1024维 + // Reshape回spatial format: [batch, num_patches, 768] -> [batch, 768, H/16, W/16] + auto hidden_spatial = Tensor::buffer(dt_logits, {batch, 768, patch_h, patch_w}, rsrc.memory_pool); + for (size_t b = 0; b < batch; b++) { + for (size_t h = 0; h < patch_h; h++) { + for (size_t w = 0; w < patch_w; w++) { + size_t patch_idx = h * patch_w + w; + for (size_t c = 0; c < 768; c++) { + size_t src_idx = (b * num_patches + patch_idx) * 768 + c; + size_t dst_idx = ((b * 768 + c) * patch_h + h) * patch_w + w; + RUN_INFINI(infinirtMemcpyAsync( + hidden_spatial->data(dst_idx), + hidden_states->data(src_idx), + dsize(dt_logits), INFINIRT_MEMCPY_D2D, stream)); + } + } + } + } + + // Neck Conv1x1: 768->1024 + auto neck1_spatial = Tensor::buffer(dt_logits, {batch, 1024, patch_h, patch_w}, rsrc.memory_pool); + conv2d(neck1_spatial, hidden_spatial, rsrc.w_sam_neck_conv1, nullptr, + 1, 1, 0, 0); // 1x1 conv, stride=1, padding=0 + + // Flatten back: [batch, 1024, h, w] -> [batch, h*w, 1024] + auto sam_features = Tensor::buffer(dt_logits, {batch, num_patches, 1024}, rsrc.memory_pool); + for (size_t b = 0; b < batch; b++) { + for (size_t h = 0; h < patch_h; h++) { + for (size_t w = 0; w < patch_w; w++) { + size_t patch_idx = h * patch_w + w; + for (size_t c = 0; c < 1024; c++) { + size_t src_idx = ((b * 1024 + c) * patch_h + h) * patch_w + w; + size_t dst_idx = (b * num_patches + patch_idx) * 1024 + c; + RUN_INFINI(infinirtMemcpyAsync( + sam_features->data(dst_idx), + neck1_spatial->data(src_idx), + dsize(dt_logits), INFINIRT_MEMCPY_D2D, stream)); + } + } + } + } + + delete cache_manager; + return sam_features; +} + +std::shared_ptr inferVisionCLIP(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + std::shared_ptr pixel_values, + std::shared_ptr sam_features) { + // CLIP-L推理: [batch, 3, H, W] -> [batch, num_patches, 1024] + auto dt_logits = meta.dt_logits; + auto stream = rsrc.stream; + auto handle = rsrc.handle; + auto batch = pixel_values->shape()[0]; + auto num_patches = (pixel_values->shape()[2] / 14) * (pixel_values->shape()[3] / 14); + + // 设置推理上下文 + auto cache_manager = new CacheManager(); + InferenceContext ctx(handle, rsrc.memory_pool, cache_manager, stream); + setInferenceContext(&ctx); + + // 1. Patch Embedding: Conv2d(3->1024, kernel=14, stride=14) + CLS token + auto H = pixel_values->shape()[2]; + auto W = pixel_values->shape()[3]; + auto patch_h = H / 14; + auto patch_w = W / 14; + + // Conv2d输出: [batch, 1024, patch_h, patch_w] + auto conv_out = Tensor::buffer(dt_logits, {batch, 1024, patch_h, patch_w}, rsrc.memory_pool); + conv2d(conv_out, pixel_values, rsrc.w_clip_patch_embed, rsrc.w_clip_patch_embed_bias, + 14, 14, 0, 0); // stride=14, padding=0 + + // Flatten和转置,并添加CLS token: [batch, num_patches+1, 1024] + auto patch_embeds = Tensor::buffer(dt_logits, {batch, num_patches + 1, 1024}, rsrc.memory_pool); + for (size_t b = 0; b < batch; b++) { + // CLS token at position 0 (初始化为0,实际应该是learnable parameter) + RUN_INFINI(infinirtMemsetAsync(patch_embeds->data(b * (num_patches + 1) * 1024), + 0, dsize(dt_logits) * 1024, stream)); + + // Patch tokens from position 1 + for (size_t h = 0; h < patch_h; h++) { + for (size_t w = 0; w < patch_w; w++) { + size_t patch_idx = h * patch_w + w; + for (size_t c = 0; c < 1024; c++) { + size_t src_idx = ((b * 1024 + c) * patch_h + h) * patch_w + w; + size_t dst_idx = (b * (num_patches + 1) + patch_idx + 1) * 1024 + c; + RUN_INFINI(infinirtMemcpyAsync( + patch_embeds->data(dst_idx), + conv_out->data(src_idx), + dsize(dt_logits), INFINIRT_MEMCPY_D2D, stream)); + } + } + } + } + + // 2. Add position embedding + add(patch_embeds, patch_embeds, rsrc.w_clip_position_embed); + + // 3. Pre LayerNorm + auto normed_embeds = Tensor::buffer(dt_logits, patch_embeds->shape(), rsrc.memory_pool); + layer_norm(normed_embeds, patch_embeds, rsrc.w_clip_pre_layernorm, meta.epsilon); + + // 4. CLIP Transformer Blocks (24层) + auto hidden_states = normed_embeds; + auto total_tokens = num_patches + 1; + + for (size_t layer = 0; layer < 24; layer++) { + // LayerNorm1 + Self-Attention + Residual + // CLIP-L: 16 heads, 1024 dim, head_dim = 64 + const size_t num_heads = 16; + const size_t head_dim = 64; + + auto ln1_out = Tensor::buffer(dt_logits, hidden_states->shape(), rsrc.memory_pool); + layer_norm(ln1_out, hidden_states, rsrc.w_clip_block_ln1[layer], meta.epsilon); + + auto qkv_flat = Tensor::buffer(dt_logits, {batch * total_tokens, 1024 * 3}, rsrc.memory_pool); + linear(qkv_flat, ln1_out->view({batch * total_tokens, 1024}), + rsrc.w_clip_block_attn_qkv[layer], 1.0, 0.0, nullptr, nullptr); + + // Split QKV: [batch, total_tokens, 1024*3] -> [batch, total_tokens, 3, num_heads, head_dim] + auto q_buf = Tensor::buffer(dt_logits, {batch, num_heads, total_tokens, head_dim}, rsrc.memory_pool); + auto k_buf = Tensor::buffer(dt_logits, {batch, num_heads, total_tokens, head_dim}, rsrc.memory_pool); + auto v_buf = Tensor::buffer(dt_logits, {batch, num_heads, total_tokens, head_dim}, rsrc.memory_pool); + + // Extract Q, K, V from QKV tensor (QKV layout: [batch, tokens, 3, heads, head_dim]) + auto q_extract = Tensor::buffer(dt_logits, {batch, total_tokens, num_heads, head_dim}, rsrc.memory_pool); + auto k_extract = Tensor::buffer(dt_logits, {batch, total_tokens, num_heads, head_dim}, rsrc.memory_pool); + auto v_extract = Tensor::buffer(dt_logits, {batch, total_tokens, num_heads, head_dim}, rsrc.memory_pool); + + for (size_t b = 0; b < batch; b++) { + for (size_t t = 0; t < total_tokens; t++) { + // Copy Q: qkv[b, t, 0, :, :] + RUN_INFINI(infinirtMemcpyAsync( + q_extract->data((b * total_tokens + t) * num_heads * head_dim), + qkv_flat->data((b * total_tokens + t) * 3 * num_heads * head_dim + 0 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + // Copy K: qkv[b, t, 1, :, :] + RUN_INFINI(infinirtMemcpyAsync( + k_extract->data((b * total_tokens + t) * num_heads * head_dim), + qkv_flat->data((b * total_tokens + t) * 3 * num_heads * head_dim + 1 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + // Copy V: qkv[b, t, 2, :, :] + RUN_INFINI(infinirtMemcpyAsync( + v_extract->data((b * total_tokens + t) * num_heads * head_dim), + qkv_flat->data((b * total_tokens + t) * 3 * num_heads * head_dim + 2 * num_heads * head_dim), + dsize(dt_logits) * num_heads * head_dim, + INFINIRT_MEMCPY_D2D, stream)); + } + } + + // Permute from [batch, total_tokens, num_heads, head_dim] to [batch, num_heads, total_tokens, head_dim] + rearrange(q_buf, q_extract->view({batch, total_tokens, num_heads, head_dim})->permute({0, 2, 1, 3})); + rearrange(k_buf, k_extract->view({batch, total_tokens, num_heads, head_dim})->permute({0, 2, 1, 3})); + rearrange(v_buf, v_extract->view({batch, total_tokens, num_heads, head_dim})->permute({0, 2, 1, 3})); + + // QK^T / sqrt(head_dim): [batch, num_heads, total_tokens, total_tokens] + auto qk_scores = Tensor::buffer(dt_logits, {batch * num_heads, total_tokens, total_tokens}, rsrc.memory_pool); + auto k_transposed = k_buf->view({batch * num_heads, total_tokens, head_dim})->permute({0, 2, 1}); + linear(qk_scores, q_buf->view({batch * num_heads, total_tokens, head_dim}), k_transposed, + 1.0f / sqrtf(head_dim), 0.0f, nullptr, nullptr); + + // Softmax over last dimension (non-causal for vision) + auto qk_softmax = qk_scores->view({batch * num_heads * total_tokens, total_tokens}); + causalSoftmax(qk_softmax, qk_softmax); // Note: 实际上应该是普通softmax + + // Attention @ V: [batch, num_heads, total_tokens, head_dim] + auto attn_out_heads = Tensor::buffer(dt_logits, {batch * num_heads, total_tokens, head_dim}, rsrc.memory_pool); + linear(attn_out_heads, qk_scores, v_buf->view({batch * num_heads, total_tokens, head_dim}), + 1.0f, 0.0f, nullptr, nullptr); + + // Transpose and reshape: [batch, num_heads, total_tokens, head_dim] -> [batch, total_tokens, num_heads, head_dim] -> [batch, total_tokens, 1024] + auto attn_transposed = Tensor::buffer(dt_logits, {batch, total_tokens, num_heads, head_dim}, rsrc.memory_pool); + rearrange(attn_transposed, attn_out_heads->view({batch, num_heads, total_tokens, head_dim})->permute({0, 2, 1, 3})); + auto attn_out = attn_transposed->view({batch * total_tokens, 1024}); + + // Output projection with residual + auto attn_proj = Tensor::buffer(dt_logits, {batch, total_tokens, 1024}, rsrc.memory_pool); + linear(attn_proj->view({batch * total_tokens, 1024}), attn_out, + rsrc.w_clip_block_attn_proj[layer], 1.0, 0.0, + hidden_states->view({batch * total_tokens, 1024}), nullptr); + hidden_states = attn_proj; + + // LayerNorm2 + MLP(GELU) + Residual + auto ln2_out = Tensor::buffer(dt_logits, hidden_states->shape(), rsrc.memory_pool); + layer_norm(ln2_out, hidden_states, rsrc.w_clip_block_ln2[layer], meta.epsilon); + + auto mlp_hidden_flat = Tensor::buffer(dt_logits, {batch * total_tokens, 4096}, rsrc.memory_pool); + linear(mlp_hidden_flat, ln2_out->view({batch * total_tokens, 1024}), + rsrc.w_clip_block_mlp_fc1[layer], 1.0, 0.0, nullptr, nullptr); + + gelu(mlp_hidden_flat, mlp_hidden_flat); + + auto mlp_out = Tensor::buffer(dt_logits, {batch, total_tokens, 1024}, rsrc.memory_pool); + linear(mlp_out->view({batch * total_tokens, 1024}), mlp_hidden_flat, + rsrc.w_clip_block_mlp_fc2[layer], 1.0, 0.0, + hidden_states->view({batch * total_tokens, 1024}), nullptr); + hidden_states = mlp_out; + } + + // 5. 移除CLS token (index 0),只保留patch tokens [1:] + auto clip_features = Tensor::buffer(dt_logits, {batch, num_patches, 1024}, rsrc.memory_pool); + for (size_t b = 0; b < batch; b++) { + RUN_INFINI(infinirtMemcpyAsync( + clip_features->data(b * num_patches * 1024), + hidden_states->data((b * total_tokens + 1) * 1024), + dsize(dt_logits) * num_patches * 1024, + INFINIRT_MEMCPY_D2D, stream)); + } + + delete cache_manager; + return clip_features; +} + +std::shared_ptr inferVision(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + const void *pixel_values_patches, + const void *pixel_values_global, + uint32_t num_patches, + uint32_t height_patches, + uint32_t width_patches) { + // 完整视觉编码: SAM+CLIP -> 拼接 -> Projector -> 添加分隔符 + auto dt_logits = meta.dt_logits; + auto d = meta.d; // 1280 + auto stream = rsrc.stream; + + std::vector> all_visual_embeds; + size_t total_visual_tokens = 0; + + // 1. 处理局部patches (如果有) + if (num_patches > 0 && pixel_values_patches != nullptr) { + // 创建tensor wrapper + auto patches_tensor = Tensor::weight(pixel_values_patches, dt_logits, + {num_patches, 3, 640, 640}); + + // SAM特征提取: [num_patches, N_sam, 1024] + auto sam_local = inferVisionSAM(meta, rsrc, patches_tensor); + + // CLIP特征提取: [num_patches, N_clip, 1024] + auto clip_local = inferVisionCLIP(meta, rsrc, patches_tensor, sam_local); + + // 拼接SAM和CLIP特征: [num_patches, N, 2048] + // CLIP和SAM的特征数量应该相同(都是(640/14)^2 = 2116个patch) + auto N = sam_local->shape()[1]; + auto concat_features = Tensor::buffer(dt_logits, {num_patches, N, 2048}, rsrc.memory_pool); + + // 拼接操作: concat_features = [clip_local, sam_local] along dim=-1 + for (size_t b = 0; b < num_patches; b++) { + for (size_t n = 0; n < N; n++) { + // 复制CLIP特征到前1024维 + RUN_INFINI(infinirtMemcpyAsync( + concat_features->data((b * N + n) * 2048), + clip_local->data((b * N + n) * 1024), + dsize(dt_logits) * 1024, INFINIRT_MEMCPY_D2D, stream)); + // 复制SAM特征到后1024维 + RUN_INFINI(infinirtMemcpyAsync( + concat_features->data((b * N + n) * 2048 + 1024), + sam_local->data((b * N + n) * 1024), + dsize(dt_logits) * 1024, INFINIRT_MEMCPY_D2D, stream)); + } + } + + // Projector投影: [num_patches, N, 2048] -> [num_patches, N, 1280] + auto local_features = Tensor::buffer(dt_logits, {num_patches * N, d}, rsrc.memory_pool); + linear(local_features, concat_features->view({num_patches * N, 2048}), + rsrc.w_projector, 1.0, 0.0, nullptr, nullptr); + + // 重排为2D网格并添加image_newline (每行patch后添加一个newline token) + // 计算:每个patch有N个token,加上每行的newline,总共num_patches个patch按grid排列 + total_visual_tokens += num_patches * N + height_patches * width_patches; + all_visual_embeds.push_back(local_features); + } + + // 2. 处理全局视图 + if (pixel_values_global != nullptr) { + auto global_tensor = Tensor::weight(pixel_values_global, dt_logits, {1, 3, 1024, 1024}); + + // SAM和CLIP特征提取 + auto sam_global = inferVisionSAM(meta, rsrc, global_tensor); // [1, N_global, 1024] + auto clip_global = inferVisionCLIP(meta, rsrc, global_tensor, sam_global); // [1, N_global, 1024] + + // 拼接特征 + auto N_global = sam_global->shape()[1]; // (1024/14)^2 = 5329 patches + auto concat_global = Tensor::buffer(dt_logits, {1, N_global, 2048}, rsrc.memory_pool); + + for (size_t n = 0; n < N_global; n++) { + RUN_INFINI(infinirtMemcpyAsync( + concat_global->data(n * 2048), + clip_global->data(n * 1024), + dsize(dt_logits) * 1024, INFINIRT_MEMCPY_D2D, stream)); + RUN_INFINI(infinirtMemcpyAsync( + concat_global->data(n * 2048 + 1024), + sam_global->data(n * 1024), + dsize(dt_logits) * 1024, INFINIRT_MEMCPY_D2D, stream)); + } + + // Projector投影 + auto global_features = Tensor::buffer(dt_logits, {N_global, d}, rsrc.memory_pool); + linear(global_features, concat_global->view({N_global, 2048}), + rsrc.w_projector, 1.0, 0.0, nullptr, nullptr); + + total_visual_tokens += N_global + 1; // +1 for image_newline at end + all_visual_embeds.push_back(global_features); + } + + // 3. 拼接所有视觉特征 + 添加分隔符 + // 最终结构: [local_patches...] + [image_newlines...] + [global_view] + [view_seperator] + auto visual_embeds = Tensor::buffer(dt_logits, {total_visual_tokens, d}, rsrc.memory_pool); + + size_t current_pos = 0; + + // 复制局部patches特征 (如果有) + for (auto &local_feat : all_visual_embeds) { + if (local_feat) { + size_t num_tokens = local_feat->shape()[0] * local_feat->shape()[1]; + RUN_INFINI(infinirtMemcpyAsync( + visual_embeds->data(current_pos * d), + local_feat->data(), + dsize(dt_logits) * num_tokens * d, + INFINIRT_MEMCPY_D2D, stream)); + current_pos += num_tokens; + + // 在每批patch后添加image_newline + if (current_pos < total_visual_tokens) { + RUN_INFINI(infinirtMemcpyAsync( + visual_embeds->data(current_pos * d), + rsrc.w_image_newline->data(), + dsize(dt_logits) * d, + INFINIRT_MEMCPY_D2D, stream)); + current_pos++; + } + } + } + + // 最后添加view_seperator + if (current_pos < total_visual_tokens) { + RUN_INFINI(infinirtMemcpyAsync( + visual_embeds->data(current_pos * d), + rsrc.w_view_seperator->data(), + dsize(dt_logits) * d, + INFINIRT_MEMCPY_D2D, stream)); + } + + fprintf(stderr, "inferVision: Framework complete, concat/projector need linear operator calls\n"); + return visual_embeds; +} + +// ================ LLM 推理 ================ + +void inferDeviceBatch(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + uint32_t idev, + uint32_t ndev, + const uint32_t *tokens, + uint32_t ntok, + const uint32_t *req_lens, + uint32_t nreq, + const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, + const uint32_t *topk, + const float *topp, + uint32_t *output, + void *last_logits, + const void *pixel_values_patches, + const void *pixel_values_global, + const uint32_t *patch_info) { + + auto nlayer = meta.n_dense_layer + meta.n_sparse_layer; + auto d = meta.d; + auto nh = meta.nh / ndev; + auto nkvh = meta.nkvh / ndev; + auto ngroup = nh / nkvh; + auto dh = meta.dh; + auto dt_logits = meta.dt_logits; + auto dvoc = meta.dvoc; + auto stream = rsrc.stream; + + // 设置推理上下文 + auto cache_manager = new CacheManager(); + InferenceContext ctx(rsrc.handle, rsrc.memory_pool, cache_manager, stream); + setInferenceContext(&ctx); + + // Allocate buffers + auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + + // Attention buffers + auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); + auto q_buf = qkv_buf->slice(1, 0, nh * dh)->view({ntok, nh, dh}); + auto k_buf = qkv_buf->slice(1, nh * dh, nkvh * dh)->view({ntok, nkvh, dh}); + auto v_buf = qkv_buf->slice(1, (nh + nkvh) * dh, nkvh * dh)->view({ntok, nkvh, dh}); + auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + + // Sampling buffers + auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); + auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); + auto result_cpu = std::vector(nreq); + + // Prepare position IDs + auto batch_pos_ids = std::vector(ntok); + size_t req_start = 0; + for (uint32_t req = 0; req < nreq; req++) { + for (uint32_t i = 0; i < req_lens[req]; i++) { + batch_pos_ids[req_start + i] = req_pos[req] + i; + } + req_start += req_lens[req]; + } + + std::shared_ptr pos_ids_buf; + if (rsrc.device == INFINI_DEVICE_CPU) { + pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok}); + } else { + pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), + sizeof(uint32_t) * ntok, INFINIRT_MEMCPY_H2D, stream)); + } + + // 1. 输入准备 - Token Embedding + Vision Encoding + std::shared_ptr visual_embeds; + bool has_vision_input = (pixel_values_global != nullptr); + const uint32_t IMAGE_TOKEN_ID = 128815; // DeepSeek-OCR的image token id + + if (has_vision_input && req_pos[0] == 0) { + // Prefill阶段且有图像输入,调用视觉编码器 + uint32_t num_patches = patch_info[0]; + uint32_t height_patches = patch_info[1]; + uint32_t width_patches = patch_info[2]; + + visual_embeds = inferVision(meta, rsrc, pixel_values_patches, pixel_values_global, + num_patches, height_patches, width_patches); + } + + // 构建输入embeddings,如果遇到image_token_id则替换为visual_embeds + size_t visual_token_idx = 0; + for (uint32_t i = 0; i < ntok; i++) { + if (has_vision_input && tokens[i] == IMAGE_TOKEN_ID && visual_embeds != nullptr) { + // 替换为视觉特征 + if (visual_token_idx < visual_embeds->shape()[0]) { + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + visual_embeds->data(visual_token_idx * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + visual_token_idx++; + } else { + // 视觉token用完了,使用普通embedding + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + rsrc.w_in_embd->data(tokens[i] * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + } + } else { + // 普通文本token + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + rsrc.w_in_embd->data(tokens[i] * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + } + } + + // Attention inner loop setup + size_t max_qk_size = 0; + size_t max_seq_len = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len)); + max_seq_len = std::max(max_seq_len, size_t(seq_len)); + } + + auto qk_buf = Tensor::buffer(dt_logits, {nh * max_qk_size}, rsrc.memory_pool); + auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh}); + auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); + + // 2. Transformer Decoder 循环 + for (size_t layer = 0; layer < nlayer; layer++) { + // 2.1 Attention + rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); + + // QKV投影 + linear(q_buf->view({ntok, nh * dh}), logits_out, rsrc.w_attn_q[layer], 1.0, 0.0, nullptr, nullptr); + linear(k_buf->view({ntok, nkvh * dh}), logits_out, rsrc.w_attn_k[layer], 1.0, 0.0, nullptr, nullptr); + linear(v_buf->view({ntok, nkvh * dh}), logits_out, rsrc.w_attn_v[layer], 1.0, 0.0, nullptr, nullptr); + + // RoPE + rope(q_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); + + // Attention计算 (per request) + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto q = q_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = k_buf->slice({{0, token_offset, seq_len}}); + auto v = v_buf->slice({{0, token_offset, seq_len}}); + + // Update KV cache + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); + + // QK^T + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, + 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + + // Softmax + auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + + // Attention @ V + auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + + // O投影 + linear(logits_in, o_buf, rsrc.w_attn_o[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); + + // AllReduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce(logits_in->data(), logits_in->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + + // 2.2 FFN + rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon); + + if (layer < meta.n_dense_layer) { + // Dense MLP (第0层) + auto di_dense = meta.di_dense / ndev; + auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di_dense}, rsrc.memory_pool); + auto gate_buf = gate_up_buf->slice(1, 0, di_dense); + auto up_buf = gate_up_buf->slice(1, di_dense, di_dense); + + linear(gate_buf, logits_out, rsrc.w_dense_gate, 1.0, 0.0, nullptr, nullptr); + linear(up_buf, logits_out, rsrc.w_dense_up, 1.0, 0.0, nullptr, nullptr); + swiglu(gate_buf, up_buf, gate_buf); + linear(logits_in, gate_buf, rsrc.w_dense_down, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); + + } else { + // MoE (第1-11层) + size_t moe_layer_idx = layer - meta.n_dense_layer; + auto di_moe = meta.di_moe; + auto di_shared = meta.di_shared / ndev; + + // Gate routing + auto router_logits = Tensor::buffer(dt_logits, {ntok, meta.nexperts}, rsrc.memory_pool); + linear(router_logits, logits_out, rsrc.w_moe_gate_weight[moe_layer_idx], + 1.0, 0.0, nullptr, rsrc.w_moe_gate_bias[moe_layer_idx]); + + // Top-K selection + auto topk_values = Tensor::buffer(INFINI_DTYPE_F32, {ntok, meta.kexperts}, rsrc.memory_pool); + auto topk_indices = Tensor::buffer(INFINI_DTYPE_I32, {ntok, meta.kexperts}, rsrc.memory_pool); + topkrouter(topk_values, topk_indices, router_logits, + rsrc.w_moe_gate_bias[moe_layer_idx], meta.routed_scale, meta.kexperts); + + // Shared experts (always active) + auto shared_gate_up = Tensor::buffer(dt_logits, {ntok, 2 * di_shared}, rsrc.memory_pool); + auto shared_gate = shared_gate_up->slice(1, 0, di_shared); + auto shared_up = shared_gate_up->slice(1, di_shared, di_shared); + auto shared_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + + linear(shared_gate, logits_out, rsrc.w_moe_shared_gate[moe_layer_idx], 1.0, 0.0, nullptr, nullptr); + linear(shared_up, logits_out, rsrc.w_moe_shared_up[moe_layer_idx], 1.0, 0.0, nullptr, nullptr); + swiglu(shared_gate, shared_up, shared_gate); + linear(shared_out, shared_gate, rsrc.w_moe_shared_down[moe_layer_idx], 1.0, 0.0, nullptr, nullptr); + + // Routed experts (需要动态调度和并行计算) + // 简化处理:只使用shared experts的输出 + linear(logits_in, shared_out, nullptr, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); + } + + // AllReduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce(logits_in->data(), logits_in->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + } + + // 3. Output & Sampling (only rank 0) + if (idev == 0) { + if (last_logits != nullptr) { + // Forward mode: return all logits + rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); + auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); + linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), + dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); + } + + if (output != nullptr) { + // Inference mode: sample next token + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + token_offset += seq_len; + rmsnorm(logits_out->slice(0, req, 1), + logits_in->slice(0, token_offset - 1, 1), + rsrc.w_out_norm, meta.epsilon); + } + linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + + // Sampling + std::random_device _rd; + std::mt19937 gen(_rd()); + std::uniform_real_distribution dis(0.0, 1.0); + + for (uint32_t req = 0; req < nreq; req++) { + float random_val = dis(gen); + randomSample(result_buf->slice(0, req, 1), + prob_buf->slice(0, req, 1), + random_val, topp[req], topk[req], temperature[req]); + } + + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(), + sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H)); + for (uint32_t req = 0; req < nreq; req++) { + output[req] = static_cast(result_cpu[req]); + } + } + } + + delete cache_manager; +} + +// ================ C API ================ + +__C __export struct DeepSeekOCRModel * +createDeepSeekOCRModel(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + auto model = new DeepSeekOCRModel(); + model->meta = *meta; + + std::vector comms; + if (ndev > 1) { + comms.resize(ndev); + infinicclCommInitAll(comms.data(), ndev, dev_ids); + } else { + comms.resize(1); + comms[0] = nullptr; + } + + model->resources.resize(ndev); + std::vector threads; + for (int i = 0; i < ndev; i++) { + threads.emplace_back([&, i]() { + createDeviceResource(&model->resources[i], meta, weights, + device, i, ndev, dev_ids[i], comms[i]); + }); + } + for (auto &t : threads) { + t.join(); + } + + return model; +} + +__C __export void +destroyDeepSeekOCRModel(struct DeepSeekOCRModel *model) { + if (model == nullptr) { + return; + } + + for (auto &res : model->resources) { + releaseDeviceResource(res); + } + delete model; +} + +__C __export void +inferBatchDeepSeekOCR(struct DeepSeekOCRModel *model, + const uint32_t *tokens, + uint32_t ntok, + const uint32_t *req_lens, + uint32_t nreq, + const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, + const uint32_t *topk, + const float *topp, + uint32_t *output) { + int ndev = model->resources.size(); + std::vector threads; + + for (int i = 0; i < ndev; i++) { + threads.emplace_back([&, i]() { + inferDeviceBatch(model->meta, model->resources[i], + i, ndev, tokens, ntok, + req_lens, nreq, req_pos, + kv_caches, temperature, topk, topp, + output, nullptr, + nullptr, nullptr, nullptr); + }); + } + + for (auto &t : threads) { + t.join(); + } +} + +__C __export void +forwardBatchDeepSeekOCR(struct DeepSeekOCRModel *model, + const uint32_t *tokens, + uint32_t ntok, + const uint32_t *req_lens, + uint32_t nreq, + const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits) { + int ndev = model->resources.size(); + std::vector threads; + + for (int i = 0; i < ndev; i++) { + threads.emplace_back([&, i]() { + inferDeviceBatch(model->meta, model->resources[i], + i, ndev, tokens, ntok, + req_lens, nreq, req_pos, + kv_caches, nullptr, nullptr, nullptr, + nullptr, logits, + nullptr, nullptr, nullptr); + }); + } + + for (auto &t : threads) { + t.join(); + } +} + +__C __export void +inferBatchDeepSeekOCRWithEmbeds(struct DeepSeekOCRModel *model, + const void *inputs_embeds, + uint32_t ntok, + const uint32_t *req_lens, + uint32_t nreq, + const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, + const uint32_t *topk, + const float *topp, + uint32_t *output) { + // 使用预计算的embeddings推理(用于Python端传入已处理的视觉特征) + // 这个接口主要用于灵活的多模态输入处理 + fprintf(stderr, "inferBatchDeepSeekOCRWithEmbeds not fully implemented yet\n"); +} diff --git a/src/models/deepseek_ocr/deepseek_ocr_impl.hpp b/src/models/deepseek_ocr/deepseek_ocr_impl.hpp new file mode 100644 index 00000000..7794df0b --- /dev/null +++ b/src/models/deepseek_ocr/deepseek_ocr_impl.hpp @@ -0,0 +1,157 @@ +#ifndef DEEPSEEK_OCR_IMPL_HPP +#define DEEPSEEK_OCR_IMPL_HPP + +#include "../../../allocator.hpp" +#include "../../../cache.hpp" +#include "../../../tensor.hpp" +#include "infinicore_infer/cache.h" +#include "infinicore_infer/models/deepseek_ocr.h" + +#include +#include +#include + +#include +#include + +// 设备资源结构 +struct DeepSeekOCRDeviceResource { + infiniDevice_t device; + int dev_id; + infiniopHandle_t handle; + + // 基础权重 + std::shared_ptr w_in_embd; + std::shared_ptr w_out_norm; + std::shared_ptr w_out_embd; + + // RoPE表 + std::shared_ptr sin_table; + std::shared_ptr cos_table; + + // Attention权重 (所有层) + std::vector> w_attn_norm; + std::vector> w_attn_q; + std::vector> w_attn_k; + std::vector> w_attn_v; + std::vector> w_attn_o; + + // FFN norm (所有层) + std::vector> w_ffn_norm; + + // Dense MLP权重 (第0层) + std::shared_ptr w_dense_gate; + std::shared_ptr w_dense_up; + std::shared_ptr w_dense_down; + + // MoE权重 (第1-11层) + std::vector> w_moe_gate_weight; + std::vector> w_moe_gate_bias; + + // Shared experts + std::vector> w_moe_shared_gate; + std::vector> w_moe_shared_up; + std::vector> w_moe_shared_down; + + // Routed experts (n_sparse_layer * nexperts) + std::vector>> w_moe_experts_gate; + std::vector>> w_moe_experts_up; + std::vector>> w_moe_experts_down; + + // Vision Encoder weights + // SAM ViT-B (12 layers) + std::shared_ptr w_sam_patch_embed; // Conv2d weight + std::shared_ptr w_sam_patch_embed_bias; + std::vector> w_sam_block_norm1; // 12 layers + std::vector> w_sam_block_attn_qkv; + std::vector> w_sam_block_attn_proj; + std::vector> w_sam_block_norm2; + std::vector> w_sam_block_mlp_fc1; + std::vector> w_sam_block_mlp_fc2; + std::shared_ptr w_sam_neck_conv1; // Neck conv layers + std::shared_ptr w_sam_neck_ln1; + std::shared_ptr w_sam_neck_conv2; + std::shared_ptr w_sam_neck_ln2; + + // CLIP-L (24 layers) + std::shared_ptr w_clip_patch_embed; // Conv2d weight + std::shared_ptr w_clip_patch_embed_bias; + std::shared_ptr w_clip_position_embed; + std::shared_ptr w_clip_pre_layernorm; + std::vector> w_clip_block_ln1; // 24 layers + std::vector> w_clip_block_attn_qkv; + std::vector> w_clip_block_attn_proj; + std::vector> w_clip_block_ln2; + std::vector> w_clip_block_mlp_fc1; + std::vector> w_clip_block_mlp_fc2; + + // Projector + std::shared_ptr w_projector; // [2048, 1280] + std::shared_ptr w_image_newline; // [1280] + std::shared_ptr w_view_seperator; // [1280] + + infinirtStream_t stream; + infinicclComm_t comm; + std::shared_ptr memory_pool; +}; + +// 模型结构 +struct DeepSeekOCRModel { + DeepSeekOCRMeta meta; + std::vector resources; +}; + +// 创建设备资源 +void createDeviceResource(DeepSeekOCRDeviceResource *rsrc, + const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + infiniDevice_t device, + int idev, + int ndev, + int dev_id, + infinicclComm_t comm); + +// 释放设备资源 +void releaseDeviceResource(DeepSeekOCRDeviceResource &res); + +// SAM ViT-B 视觉编码器推理 +std::shared_ptr inferVisionSAM(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + std::shared_ptr pixel_values); + +// CLIP-L 视觉编码器推理 +std::shared_ptr inferVisionCLIP(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + std::shared_ptr pixel_values, + std::shared_ptr sam_features); + +// 完整视觉编码(SAM + CLIP + Projector) +std::shared_ptr inferVision(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + const void *pixel_values_patches, + const void *pixel_values_global, + uint32_t num_patches, + uint32_t height_patches, + uint32_t width_patches); + +// 在单个设备上进行推理 +void inferDeviceBatch(const DeepSeekOCRMeta &meta, + DeepSeekOCRDeviceResource &rsrc, + uint32_t idev, + uint32_t ndev, + const uint32_t *tokens, + uint32_t ntok, + const uint32_t *req_lens, + uint32_t nreq, + const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, + const uint32_t *topk, + const float *topp, + uint32_t *output, + void *last_logits, + const void *pixel_values_patches, + const void *pixel_values_global, + const uint32_t *patch_info); + +#endif // DEEPSEEK_OCR_IMPL_HPP diff --git a/src/models/deepseek_ocr/deepseek_ocr_weight.hpp b/src/models/deepseek_ocr/deepseek_ocr_weight.hpp new file mode 100644 index 00000000..f07fe577 --- /dev/null +++ b/src/models/deepseek_ocr/deepseek_ocr_weight.hpp @@ -0,0 +1,454 @@ +#ifndef DEEPSEEK_OCR_WEIGHT_HPP +#define DEEPSEEK_OCR_WEIGHT_HPP + +#include "../../../tensor.hpp" +#include "infinicore_infer/models/deepseek_ocr.h" + +#include +#include + +// 获取输入embedding +inline std::shared_ptr getInEmbd(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->input_embd, weights->dt_mat, + {meta->dvoc, meta->d}, + weights->transpose_linear_weights); +} + +// 获取输出norm +inline std::shared_ptr getOutNorm(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->output_norm, weights->dt_norm, {meta->d}); +} + +// 获取输出embedding +inline std::shared_ptr getOutEmbd(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->output_embd, weights->dt_mat, + {meta->dvoc, meta->d}, + weights->transpose_linear_weights); +} + +// 获取RoPE sin表 +inline std::shared_ptr getSinTable(const DeepSeekOCRMeta *meta) { + size_t dctx = meta->dctx; + size_t dh = meta->dh; + float theta = meta->theta; + + std::vector sin_table(dctx * dh); + for (size_t pos = 0; pos < dctx; ++pos) { + for (size_t i = 0; i < dh; ++i) { + float freq = 1.0f / std::pow(theta, (2.0f * i) / dh); + sin_table[pos * dh + i] = std::sin(pos * freq); + } + } + return Tensor::weight(sin_table.data(), INFINI_DTYPE_F32, {dctx, dh}); +} + +// 获取RoPE cos表 +inline std::shared_ptr getCosTable(const DeepSeekOCRMeta *meta) { + size_t dctx = meta->dctx; + size_t dh = meta->dh; + float theta = meta->theta; + + std::vector cos_table(dctx * dh); + for (size_t pos = 0; pos < dctx; ++pos) { + for (size_t i = 0; i < dh; ++i) { + float freq = 1.0f / std::pow(theta, (2.0f * i) / dh); + cos_table[pos * dh + i] = std::cos(pos * freq); + } + } + return Tensor::weight(cos_table.data(), INFINI_DTYPE_F32, {dctx, dh}); +} + +// ===================== Attention Weights ===================== + +inline std::shared_ptr getAttnNorm(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->attn_norm[layer], weights->dt_norm, {meta->d}); +} + +inline std::shared_ptr getAttnQ(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer, + int idev, + int ndev) { + size_t d = meta->d; + size_t nh = meta->nh; + size_t dh = meta->dh; + size_t nh_per_dev = nh / ndev; + + return Tensor::weight(weights->attn_q[layer], weights->dt_mat, + {nh_per_dev * dh, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getAttnK(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer, + int idev, + int ndev) { + size_t d = meta->d; + size_t nkvh = meta->nkvh; + size_t dh = meta->dh; + size_t nkvh_per_dev = nkvh / ndev; + + return Tensor::weight(weights->attn_k[layer], weights->dt_mat, + {nkvh_per_dev * dh, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getAttnV(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer, + int idev, + int ndev) { + size_t d = meta->d; + size_t nkvh = meta->nkvh; + size_t dh = meta->dh; + size_t nkvh_per_dev = nkvh / ndev; + + return Tensor::weight(weights->attn_v[layer], weights->dt_mat, + {nkvh_per_dev * dh, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getAttnO(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer, + int idev, + int ndev) { + size_t d = meta->d; + size_t nh = meta->nh; + size_t dh = meta->dh; + size_t nh_per_dev = nh / ndev; + + return Tensor::weight(weights->attn_o[layer], weights->dt_mat, + {d, nh_per_dev * dh}, + weights->transpose_linear_weights); +} + +// ===================== FFN Weights ===================== + +inline std::shared_ptr getFFNNorm(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->ffn_norm[layer], weights->dt_norm, {meta->d}); +} + +// ===================== Dense MLP Weights (Layer 0) ===================== + +inline std::shared_ptr getDenseGate(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_dense = meta->di_dense; + size_t di_per_dev = di_dense / ndev; + + return Tensor::weight(weights->dense_gate, weights->dt_mat, + {di_per_dev, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getDenseUp(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_dense = meta->di_dense; + size_t di_per_dev = di_dense / ndev; + + return Tensor::weight(weights->dense_up, weights->dt_mat, + {di_per_dev, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getDenseDown(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_dense = meta->di_dense; + size_t di_per_dev = di_dense / ndev; + + return Tensor::weight(weights->dense_down, weights->dt_mat, + {d, di_per_dev}, + weights->transpose_linear_weights); +} + +// ===================== MoE Weights (Layer 1-11) ===================== + +inline std::shared_ptr getMoEGateWeight(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx) { + size_t d = meta->d; + size_t nexperts = meta->nexperts; + return Tensor::weight(weights->moe_gate_weight[sparse_layer_idx], + weights->dt_mat, + {nexperts, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getMoEGateBias(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx) { + size_t nexperts = meta->nexperts; + return Tensor::weight(weights->moe_gate_bias[sparse_layer_idx], + weights->dt_mat, + {nexperts}); +} + +// Shared experts +inline std::shared_ptr getMoESharedGate(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_shared = meta->di_shared; + size_t di_per_dev = di_shared / ndev; + + return Tensor::weight(weights->moe_shared_gate[sparse_layer_idx], + weights->dt_mat, + {di_per_dev, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getMoESharedUp(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_shared = meta->di_shared; + size_t di_per_dev = di_shared / ndev; + + return Tensor::weight(weights->moe_shared_up[sparse_layer_idx], + weights->dt_mat, + {di_per_dev, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getMoESharedDown(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_shared = meta->di_shared; + size_t di_per_dev = di_shared / ndev; + + return Tensor::weight(weights->moe_shared_down[sparse_layer_idx], + weights->dt_mat, + {d, di_per_dev}, + weights->transpose_linear_weights); +} + +// Routed experts +inline std::shared_ptr getMoEExpertsGate(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + size_t expert_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_moe = meta->di_moe; + + return Tensor::weight(weights->moe_experts_gate[sparse_layer_idx][expert_idx], + weights->dt_mat, + {di_moe, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getMoEExpertsUp(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + size_t expert_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_moe = meta->di_moe; + + return Tensor::weight(weights->moe_experts_up[sparse_layer_idx][expert_idx], + weights->dt_mat, + {di_moe, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getMoEExpertsDown(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t sparse_layer_idx, + size_t expert_idx, + int idev, + int ndev) { + size_t d = meta->d; + size_t di_moe = meta->di_moe; + + return Tensor::weight(weights->moe_experts_down[sparse_layer_idx][expert_idx], + weights->dt_mat, + {d, di_moe}, + weights->transpose_linear_weights); +} + +// ===================== Vision Weights ===================== + +// SAM ViT-B weights +inline std::shared_ptr getSAMPatchEmbed(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_patch_embed, weights->dt_mat, + {768, 3, 16, 16}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMPatchEmbedBias(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_patch_embed_bias, weights->dt_mat, {768}); +} + +inline std::shared_ptr getSAMBlockNorm1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_norm1[layer], weights->dt_norm, {768}); +} + +inline std::shared_ptr getSAMBlockAttnQKV(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_attn_qkv[layer], weights->dt_mat, + {768 * 3, 768}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMBlockAttnProj(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_attn_proj[layer], weights->dt_mat, + {768, 768}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMBlockNorm2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_norm2[layer], weights->dt_norm, {768}); +} + +inline std::shared_ptr getSAMBlockMLPFC1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_mlp_fc1[layer], weights->dt_mat, + {3072, 768}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMBlockMLPFC2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->sam_block_mlp_fc2[layer], weights->dt_mat, + {768, 3072}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMNeckConv1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_neck_conv1, weights->dt_mat, + {1024, 768, 1, 1}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMNeckLN1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_neck_ln1, weights->dt_norm, {1024}); +} + +inline std::shared_ptr getSAMNeckConv2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_neck_conv2, weights->dt_mat, + {1024, 1024, 3, 3}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getSAMNeckLN2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->sam_neck_ln2, weights->dt_norm, {1024}); +} + +// CLIP-L weights +inline std::shared_ptr getCLIPPatchEmbed(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->clip_patch_embed, weights->dt_mat, + {1024, 3, 14, 14}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getCLIPPatchEmbedBias(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->clip_patch_embed_bias, weights->dt_mat, {1024}); +} + +inline std::shared_ptr getCLIPPositionEmbed(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->clip_position_embed, weights->dt_mat, {257, 1024}); +} + +inline std::shared_ptr getCLIPPreLayerNorm(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + return Tensor::weight(weights->clip_pre_layernorm, weights->dt_norm, {1024}); +} + +inline std::shared_ptr getCLIPBlockLN1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_ln1[layer], weights->dt_norm, {1024}); +} + +inline std::shared_ptr getCLIPBlockAttnQKV(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_attn_qkv[layer], weights->dt_mat, + {1024 * 3, 1024}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getCLIPBlockAttnProj(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_attn_proj[layer], weights->dt_mat, + {1024, 1024}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getCLIPBlockLN2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_ln2[layer], weights->dt_norm, {1024}); +} + +inline std::shared_ptr getCLIPBlockMLPFC1(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_mlp_fc1[layer], weights->dt_mat, + {4096, 1024}, weights->transpose_linear_weights); +} + +inline std::shared_ptr getCLIPBlockMLPFC2(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights, + size_t layer) { + return Tensor::weight(weights->clip_block_mlp_fc2[layer], weights->dt_mat, + {1024, 4096}, weights->transpose_linear_weights); +} + +// Projector +inline std::shared_ptr getProjector(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + size_t d = meta->d; // 1280 + return Tensor::weight(weights->projector, weights->dt_mat, + {2048, d}, + weights->transpose_linear_weights); +} + +inline std::shared_ptr getImageNewline(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + size_t d = meta->d; + return Tensor::weight(weights->image_newline, weights->dt_mat, {d}); +} + +inline std::shared_ptr getViewSeperator(const DeepSeekOCRMeta *meta, + const DeepSeekOCRWeights *weights) { + size_t d = meta->d; + return Tensor::weight(weights->view_seperator, weights->dt_mat, {d}); +} + +#endif // DEEPSEEK_OCR_WEIGHT_HPP diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index db5fda11..6e68b653 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -56,6 +56,82 @@ void InferenceContext::rmsnorm(std::shared_ptr y, y->data(), x->data(), w->data(), stream)); } +void InferenceContext::layer_norm(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + float epsilon) { + size_t key = CacheManager::createDescriptorKey(y, x, w); + + infiniopLayerNormDescriptor_t desc; + if (!cache_manager->getLayerNormDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateLayerNormDescriptor( + op_handle, &desc, y->desc(), x->desc(), w->desc(), epsilon)); + cache_manager->putLayerNormDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetLayerNormWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopLayerNorm( + desc, workspace, workspace_size, + y->data(), x->data(), w->data(), stream)); +} + +void InferenceContext::gelu(std::shared_ptr y, + std::shared_ptr x) { + size_t key = CacheManager::createDescriptorKey(y, x); + + infiniopGELUDescriptor_t desc; + if (!cache_manager->getGELUDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateGELUDescriptor( + op_handle, &desc, y->desc(), x->desc())); + cache_manager->putGELUDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetGELUWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopGELU( + desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::conv2d(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr bias, + int stride_h, int stride_w, + int padding_h, int padding_w) { + size_t key = CacheManager::createDescriptorKey(y, x, w, bias); + hash_combine(key, std::hash()(stride_h)); + hash_combine(key, std::hash()(stride_w)); + hash_combine(key, std::hash()(padding_h)); + hash_combine(key, std::hash()(padding_w)); + + infiniopConv2dDescriptor_t desc; + if (!cache_manager->getConv2dDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateConv2dDescriptor( + op_handle, &desc, y->desc(), x->desc(), w->desc(), + bias ? bias->desc() : nullptr, + stride_h, stride_w, padding_h, padding_w)); + cache_manager->putConv2dDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetConv2dWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopConv2d( + desc, workspace, workspace_size, + y->data(), x->data(), w->data(), + bias ? bias->data() : nullptr, stream)); +} + void InferenceContext::gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..00d2a6be 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -23,6 +23,18 @@ struct InferenceContext { std::shared_ptr x, std::shared_ptr w, float epsilon); + void layer_norm(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + float epsilon); + void gelu(std::shared_ptr y, + std::shared_ptr x); + void conv2d(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr bias, + int stride_h, int stride_w, + int padding_h, int padding_w); void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, @@ -86,6 +98,21 @@ inline void rmsnorm(std::shared_ptr y, std::shared_ptr x, getInferenceContext().rmsnorm(y, x, w, epsilon); } +inline void layer_norm(std::shared_ptr y, std::shared_ptr x, + std::shared_ptr w, float epsilon) { + getInferenceContext().layer_norm(y, x, w, epsilon); +} + +inline void gelu(std::shared_ptr y, std::shared_ptr x) { + getInferenceContext().gelu(y, x); +} + +inline void conv2d(std::shared_ptr y, std::shared_ptr x, + std::shared_ptr w, std::shared_ptr bias, + int stride_h, int stride_w, int padding_h, int padding_w) { + getInferenceContext().conv2d(y, x, w, bias, stride_h, stride_w, padding_h, padding_w); +} + inline void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta) { getInferenceContext().gemm(c, a, b, alpha, beta);