From bdb1848dc9893c1fe59203e7adaddf9525c5f85c Mon Sep 17 00:00:00 2001 From: hejianlin <892082223@qq.com> Date: Thu, 25 Dec 2025 22:08:44 +0800 Subject: [PATCH 1/3] qwen3vl text infer completed. --- include/infinicore_infer.h | 4 + include/infinicore_infer/models/qwen3vl.h | 205 +++++++ scripts/libinfinicore_infer/__init__.py | 19 + scripts/libinfinicore_infer/qwen3vl.py | 268 +++++++++ scripts/qwen3vl.py | 620 +++++++++++++++++++++ scripts/qwen3vl_test.py | 102 ++++ src/cache_manager/opcache_manager.hpp | 4 + src/models/deepseek_v3/deepseek_v3.cpp | 4 +- src/models/inference_context.cpp | 41 ++ src/models/inference_context.hpp | 13 + src/models/jiuge/jiuge.cpp | 2 + src/models/qwen3vl/qwen3vl.cpp | 415 ++++++++++++++ src/models/qwen3vl/qwen3vl_cache.cpp | 43 ++ src/models/qwen3vl/qwen3vl_impl.hpp | 130 +++++ src/models/qwen3vl/qwen3vl_weight.cpp | 634 ++++++++++++++++++++++ src/tensor/tensor.cpp | 2 +- 16 files changed, 2503 insertions(+), 3 deletions(-) create mode 100644 include/infinicore_infer/models/qwen3vl.h create mode 100644 scripts/libinfinicore_infer/qwen3vl.py create mode 100644 scripts/qwen3vl.py create mode 100644 scripts/qwen3vl_test.py create mode 100644 src/models/qwen3vl/qwen3vl.cpp create mode 100644 src/models/qwen3vl/qwen3vl_cache.cpp create mode 100644 src/models/qwen3vl/qwen3vl_impl.hpp create mode 100644 src/models/qwen3vl/qwen3vl_weight.cpp diff --git a/include/infinicore_infer.h b/include/infinicore_infer.h index 0bed7bc7..5b2ceb99 100644 --- a/include/infinicore_infer.h +++ b/include/infinicore_infer.h @@ -4,7 +4,11 @@ #include "infinicore_infer/cache.h" #include "infinicore_infer/weights_loader.h" + #include "infinicore_infer/models/deepseek.h" #include "infinicore_infer/models/jiuge.h" +#include "infinicore_infer/models/jiuge_awq.h" +#include "infinicore_infer/models/qwen3vl.h" + #endif /* INFINICORE_INFER_H */ diff --git a/include/infinicore_infer/models/qwen3vl.h b/include/infinicore_infer/models/qwen3vl.h new file mode 100644 index 
00000000..ea8e6eee --- /dev/null +++ b/include/infinicore_infer/models/qwen3vl.h @@ -0,0 +1,205 @@ +#ifndef QWEN3VL_WEIGHTS_H +#define QWEN3VL_WEIGHTS_H + +#include +#include +#include + +#include +#include + +struct Qwen3vlWeights; + +// Function pointer signatures +typedef void (*qwen3vl_load_global_fn)(Qwen3vlWeights *, void *cpu_ptr); +typedef void (*qwen3vl_load_layer_fn)(Qwen3vlWeights *, void *cpu_ptr, size_t layer_id); +// Struct containing all weight loading functions +typedef struct { + // Global + qwen3vl_load_global_fn load_input_embd; + qwen3vl_load_global_fn load_output_norm; + qwen3vl_load_global_fn load_output_embd; + + // Attention + qwen3vl_load_layer_fn load_attn_norm; + qwen3vl_load_layer_fn load_attn_q_norm; + qwen3vl_load_layer_fn load_attn_k_norm; + qwen3vl_load_layer_fn load_attn_qkv_proj; + qwen3vl_load_layer_fn load_attn_o_proj; + + // MLP + qwen3vl_load_layer_fn load_mlp_norm; + qwen3vl_load_layer_fn load_mlp_gate_up; + qwen3vl_load_layer_fn load_mlp_down; + +} Qwen3vlLangWeightLoader; + +typedef struct { + // Patch_embed + qwen3vl_load_global_fn load_patch_embed_weight; + qwen3vl_load_global_fn load_patch_embed_bias; + qwen3vl_load_global_fn load_pos_embed_weight; + + // blocks attn + qwen3vl_load_layer_fn load_attn_proj_weight; + qwen3vl_load_layer_fn load_attn_proj_bias; + qwen3vl_load_layer_fn load_attn_qkv_weight; + qwen3vl_load_layer_fn load_attn_qkv_bias; + + //block mlp + qwen3vl_load_layer_fn load_mlp_linear_fc1_weight; + qwen3vl_load_layer_fn load_mlp_linear_fc1_bias; + qwen3vl_load_layer_fn load_mlp_linear_fc2_weight; + qwen3vl_load_layer_fn load_mlp_linear_fc2_bias; + + //block norm + qwen3vl_load_layer_fn load_norm1_weight; + qwen3vl_load_layer_fn load_norm1_bias; + qwen3vl_load_layer_fn load_norm2_weight; + qwen3vl_load_layer_fn load_norm2_bias; + + //deepstack_merger + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight; + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias; + qwen3vl_load_layer_fn 
load_deepstack_merger_linear_fc2_weight; + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_bias; + qwen3vl_load_layer_fn load_deepstack_merger_norm_weight; + qwen3vl_load_layer_fn load_deepstack_merger_norm_bias; + + //merger + qwen3vl_load_global_fn load_merger_linear_fc1_weight; + qwen3vl_load_global_fn load_merger_linear_fc1_bias; + qwen3vl_load_global_fn load_merger_linear_fc2_weight; + qwen3vl_load_global_fn load_merger_linear_fc2_bias; + qwen3vl_load_global_fn load_merger_norm_weight; + qwen3vl_load_global_fn load_merger_norm_bias; + +} Qwen3vlVisWeightLoader; + +typedef struct { + Qwen3vlLangWeightLoader lang_loader; + Qwen3vlVisWeightLoader vis_loader; +} Qwen3vlWeightLoader; + +struct Qwen3vlModel; + +typedef struct { + size_t bos_token_id; + size_t eos_token_id; + size_t head_dim; + size_t hidden_size; + float initializer_range; + size_t intermediate_size; + size_t max_tokens; + size_t num_attention_heads; + size_t num_hidden_layers; + size_t num_key_value_heads; + float rms_norm_eps; + size_t mrope_section[3]; + size_t rope_theta; + size_t vocab_size; +} Qwen3vlTextMeta; + +typedef struct { + size_t depth; + size_t deepstack_visual_indexes[3]; + size_t hidden_size; + size_t in_channels; + float initializer_range; + size_t intermediate_size; + size_t num_heads; + size_t num_position_embeddings; + size_t out_hidden_size; + size_t patch_size; + size_t spatial_merge_size; + size_t temporal_patch_size; +} Qwen3vlVisMeta; + +typedef struct { + infiniDtype_t dtype; //INFINI_DTYPE_BF16 + + Qwen3vlTextMeta text_meta; + Qwen3vlVisMeta vis_meta; + + size_t image_token_id; + size_t video_token_id; + size_t vision_end_token_id; + size_t vision_start_token_id; +} Qwen3vlMeta; + +//////////////////// APIs /////////////////////// +/// @brief 创建模型 +/// @param device 协处理器种类 +/// @param ndev 协处理器数量 +/// @param dev_ids 协处理器编号,长度为 ndev +__C __export struct Qwen3vlModel * +createQwen3vlModel(const Qwen3vlMeta *, + const Qwen3vlWeights *); + +__C Qwen3vlWeights * 
+createQwen3vlWeights(const Qwen3vlMeta *meta, + infiniDevice_t device, + int ndev, + const int *dev_ids, + bool transpose_weight); + +__C __export Qwen3vlWeightLoader * +createQwen3vlWeightLoader(); + +/// @brief 销毁模型 +__C __export void destroyQwen3vlModel(struct Qwen3vlModel *); + +__C __export struct Qwen3vlCache * +createQwen3vlCache(const struct Qwen3vlModel *); + +__C __export void +dropQwen3vlCache(const struct Qwen3vlModel *, + struct Qwen3vlCache *); + +/// @brief 批次推理一轮,并采样出新的 token +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param temperature 采样温度(0. 表示贪心采样) +/// @param topk 采样 topk(1 表示贪心采样) +/// @param topp 采样 topp +/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq +__C __export void +inferBatchQwen3vl(struct Qwen3vlModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮,输出 output embedding 后的 logits +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq +__C __export void +forwardBatchQwen3vl(struct Qwen3vlModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **caches, + void *logits); + +#endif // QWEN3VL_WEIGHTS_H + +// self, +// input_ids: torch.LongTensor = None, +// attention_mask: Optional[torch.Tensor] = None, +// position_ids: Optional[torch.LongTensor] = None, +// past_key_values: Optional[Cache] = None, +// inputs_embeds: Optional[torch.FloatTensor] = None, +// pixel_values: Optional[torch.Tensor] = None, +// pixel_values_videos: 
Optional[torch.FloatTensor] = None, +// image_grid_thw: Optional[torch.LongTensor] = None, +// video_grid_thw: Optional[torch.LongTensor] = None, +// cache_position: Optional[torch.LongTensor] = None, \ No newline at end of file diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 8fc5f4db..3cd85f4f 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -8,6 +8,17 @@ DeepSeekV3WeightLoaderCStruct, DeepSeekV3CacheCStruct, ) +from .qwen3vl import ( + Qwen3vlModel, + Qwen3vlMetaCStruct, + TextMetaCStruct, + VisMetaCStruct, + Qwen3vlWeightsCStruct, + Qwen3vlWeightLoaderCStruct, + Qwen3vlVisWeightLoaderCStruct, + Qwen3vlLangWeightLoaderCStruct, + Qwen3vlCacheCStruct, +) __all__ = [ "DataType", @@ -23,5 +34,13 @@ "DeepSeekV3MetaCStruct", "DeepSeekV3WeightsCStruct", "DeepSeekV3WeightLoaderCStruct", + "Qwen3vlModel", + "Qwen3vlMetaCStruct", + "TextMetaCStruct", + "VisMetaCStruct", + "Qwen3vlWeightsCStruct", + "Qwen3vlWeightLoaderCStruct", + "Qwen3vlVisWeightLoaderCStruct", + "Qwen3vlLangWeightLoaderCStruct", "ModelRegister", ] diff --git a/scripts/libinfinicore_infer/qwen3vl.py b/scripts/libinfinicore_infer/qwen3vl.py new file mode 100644 index 00000000..949ba228 --- /dev/null +++ b/scripts/libinfinicore_infer/qwen3vl.py @@ -0,0 +1,268 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import ( + c_size_t, + c_uint, + c_int, + c_float, + c_void_p, + c_bool, + POINTER, + Structure, + CFUNCTYPE, +) + + +class TextMetaCStruct(Structure): + _fields_ = [ + ("bos_token_id", c_size_t), + ("eos_token_id", c_size_t), + ("head_dim", c_size_t), + ("hidden_size", c_size_t), + ("initializer_range", c_float), + ("intermediate_size", c_size_t), + ("max_tokens", c_size_t), + ("num_attention_heads", c_size_t), + ("num_hidden_layers", c_size_t), + ("num_key_value_heads", c_size_t), + ("rms_norm_eps", c_float), + ("mrope_section", c_size_t * 3), + 
("rope_theta", c_size_t), + ("vocab_size", c_size_t), + ] + + +class VisMetaCStruct(Structure): + _fields_ = [ + ("depth", c_size_t), + ("deepstack_visual_indexes", c_size_t * 3), + ("hidden_size", c_size_t), + ("in_channels", c_size_t), + ("initializer_range", c_float), + ("intermediate_size", c_size_t), + ("num_heads", c_size_t), + ("num_position_embeddings", c_size_t), + ("out_hidden_size", c_size_t), + ("patch_size", c_size_t), + ("spatial_merge_size", c_size_t), + ("temporal_patch_size", c_size_t), + ] + + +class Qwen3vlMetaCStruct(Structure): + _fields_ = [ + ("dtype", DataType), + ("text_meta", TextMetaCStruct), + ("vis_meta", VisMetaCStruct), + # Token ids + ("image_token_id", c_size_t), + ("video_token_id", c_size_t), + ("vision_end_token_id", c_size_t), + ("vision_start_token_id", c_size_t), + ] + + +class Qwen3vlWeightsCStruct(Structure): + pass + + +class Qwen3vlModelCStruct(Structure): + pass + + +class Qwen3vlCacheCStruct(Structure): + pass + + +load_global_fn = CFUNCTYPE(None, POINTER(Qwen3vlWeightsCStruct), c_void_p) +load_layer_fn = CFUNCTYPE(None, POINTER(Qwen3vlWeightsCStruct), c_void_p, c_size_t) + + +class Qwen3vlLangWeightLoaderCStruct(Structure): + _fields_ = [ + # Global + ("load_input_embd", load_global_fn), + ("load_output_norm", load_global_fn), + ("load_output_embd", load_global_fn), + # Attention + ("load_attn_norm", load_layer_fn), + ("load_attn_q_norm", load_layer_fn), + ("load_attn_k_norm", load_layer_fn), + ("load_attn_qkv_proj", load_layer_fn), + ("load_attn_o_proj", load_layer_fn), + # MLP + ("load_mlp_norm", load_layer_fn), + ("load_mlp_gate_up", load_layer_fn), + ("load_mlp_down", load_layer_fn), + ] + + +class Qwen3vlVisWeightLoaderCStruct(Structure): + _fields_ = [ + # Patch embed + ("load_patch_embed_weight", load_global_fn), + ("load_patch_embed_bias", load_global_fn), + ("load_pos_embed_weight", load_global_fn), + # Blocks attention + ("load_attn_proj_weight", load_layer_fn), + ("load_attn_proj_bias", load_layer_fn), + 
("load_attn_qkv_weight", load_layer_fn), + ("load_attn_qkv_bias", load_layer_fn), + # Blocks MLP + ("load_mlp_linear_fc1_weight", load_layer_fn), + ("load_mlp_linear_fc1_bias", load_layer_fn), + ("load_mlp_linear_fc2_weight", load_layer_fn), + ("load_mlp_linear_fc2_bias", load_layer_fn), + # Blocks norm + ("load_norm1_weight", load_layer_fn), + ("load_norm1_bias", load_layer_fn), + ("load_norm2_weight", load_layer_fn), + ("load_norm2_bias", load_layer_fn), + # Deepstack merger + ("load_deepstack_merger_linear_fc1_weight", load_layer_fn), + ("load_deepstack_merger_linear_fc1_bias", load_layer_fn), + ("load_deepstack_merger_linear_fc2_weight", load_layer_fn), + ("load_deepstack_merger_linear_fc2_bias", load_layer_fn), + ("load_deepstack_merger_norm_weight", load_layer_fn), + ("load_deepstack_merger_norm_bias", load_layer_fn), + # Merger + ("load_merger_linear_fc1_weight", load_global_fn), + ("load_merger_linear_fc1_bias", load_global_fn), + ("load_merger_linear_fc2_weight", load_global_fn), + ("load_merger_linear_fc2_bias", load_global_fn), + ("load_merger_norm_weight", load_global_fn), + ("load_merger_norm_bias", load_global_fn), + ] + + +class Qwen3vlWeightLoaderCStruct(Structure): + _fields_ = [ + ("lang_loader", Qwen3vlLangWeightLoaderCStruct), + ("vis_loader", Qwen3vlVisWeightLoaderCStruct), + ] + + +@register_model +class Qwen3vlModel(BaseModel): + @classmethod + def register_lib(cls, lib): + """Register Qwen3vl model functions with the library""" + lib.createQwen3vlWeightLoader.argtypes = [] + lib.createQwen3vlWeightLoader.restype = POINTER(Qwen3vlWeightLoaderCStruct) + + lib.createQwen3vlWeights.argtypes = [ + POINTER(Qwen3vlMetaCStruct), + DeviceType, + c_int, + POINTER(c_int), + c_bool, + ] + lib.createQwen3vlWeights.restype = POINTER(Qwen3vlWeightsCStruct) + + lib.createQwen3vlModel.argtypes = [ + POINTER(Qwen3vlMetaCStruct), + POINTER(Qwen3vlWeightsCStruct), + ] + lib.createQwen3vlModel.restype = POINTER(Qwen3vlModelCStruct) + + 
lib.destroyQwen3vlModel.argtypes = [POINTER(Qwen3vlModelCStruct)] + + lib.createQwen3vlCache.argtypes = [POINTER(Qwen3vlModelCStruct)] + lib.createQwen3vlCache.restype = POINTER(Qwen3vlCacheCStruct) + + lib.dropQwen3vlCache.argtypes = [ + POINTER(Qwen3vlModelCStruct), + POINTER(Qwen3vlCacheCStruct), + ] + + lib.inferBatchQwen3vl.argtypes = [ + POINTER(Qwen3vlModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(Qwen3vlCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.forwardBatchQwen3vl.argtypes = [ + POINTER(Qwen3vlModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(Qwen3vlCacheCStruct)), + c_void_p, + ] + + def create_weight_loader(self): + return self.lib.createQwen3vlWeightLoader() + + def create_weights(self, meta, device_type, ndev, dev_ids, transpose_weight): + return self.lib.createQwen3vlWeights(meta, device_type, ndev, dev_ids, transpose_weight) + + def create_model(self, meta, weights): + return self.lib.createQwen3vlModel(meta, weights) + + def destroy_model(self, model): + self.lib.destroyQwen3vlModel(model) + + def create_cache(self, model): + return self.lib.createQwen3vlCache(model) + + def drop_cache(self, model, cache): + self.lib.dropQwen3vlCache(model, cache) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchQwen3vl( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + temperature, + topk, + topp, + output, + ) + + def forward_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + logits, + ): + self.lib.forwardBatchQwen3vl( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + logits, + ) \ No newline at end of file diff --git a/scripts/qwen3vl.py b/scripts/qwen3vl.py new file mode 100644 index 
00000000..12cc2d72 --- /dev/null +++ b/scripts/qwen3vl.py @@ -0,0 +1,620 @@ +import ctypes +from typing import List, Sequence + +from tqdm import tqdm + +from libinfinicore_infer import ( + Qwen3vlModel, + Qwen3vlMetaCStruct, + TextMetaCStruct, + VisMetaCStruct, + Qwen3vlWeightsCStruct, + Qwen3vlCacheCStruct, + DataType, + DeviceType, +) +from infer_task import InferTask, KVCache + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref, c_bool +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import math +import torch +import transformers + +torch.set_default_device("cpu") + + +class Qwen3vlLangWeightsNaming: + def input_embd(self): + return "model.language_model.embed_tokens.weight" + + def output_embd(self): + return "model.language_model.embed_tokens.weight" + + def output_norm(self): + return "model.language_model.norm.weight" + + def attn_norm(self, i): + return f"model.language_model.layers.{i}.input_layernorm.weight" + + def attn_q_proj(self, i): + return f"model.language_model.layers.{i}.self_attn.q_proj.weight" + + def attn_q_norm(self, i): + return f"model.language_model.layers.{i}.self_attn.q_norm.weight" + + def attn_k_proj(self, i): + return f"model.language_model.layers.{i}.self_attn.k_proj.weight" + + def attn_k_norm(self, i): + return f"model.language_model.layers.{i}.self_attn.k_norm.weight" + + def attn_o_proj(self, i): + return f"model.language_model.layers.{i}.self_attn.o_proj.weight" + + def attn_v_proj(self, i): + return f"model.language_model.layers.{i}.self_attn.v_proj.weight" + + def mlp_norm(self, i): + return f"model.language_model.layers.{i}.post_attention_layernorm.weight" + + def mlp_gate(self, i): + return f"model.language_model.layers.{i}.mlp.gate_proj.weight" + + def mlp_down(self, i): + return f"model.language_model.layers.{i}.mlp.down_proj.weight" + + def mlp_up(self, i): + return f"model.language_model.layers.{i}.mlp.up_proj.weight" + +class Qwen3vlVisWeightsNaming: + 
def patch_embed_weight(self): + return "model.visual.patch_embed.proj.weight" + def patch_embed_bias(self): + return "model.visual.patch_embed.proj.bias" + def pos_embed_weight(self): + return "model.visual.pos_embed.weight" + def attn_proj_weight(self,i): + return f"model.visual.blocks.{i}.attn.proj.weight" + def attn_proj_bias(self,i): + return f"model.visual.blocks.{i}.attn.proj.bias" + def attn_qkv_weight(self,i): + return f"model.visual.blocks.{i}.attn.qkv.weight" + def attn_qkv_bias(self,i): + return f"model.visual.blocks.{i}.attn.qkv.bias" + def mlp_linear_fc1_weight(self,i): + return f"model.visual.blocks.{i}.mlp.linear_fc1.weight" + def mlp_linear_fc1_bias(self,i): + return f"model.visual.blocks.{i}.mlp.linear_fc1.bias" + def mlp_linear_fc2_weight(self,i): + return f"model.visual.blocks.{i}.mlp.linear_fc2.weight" + def mlp_linear_fc2_bias(self,i): + return f"model.visual.blocks.{i}.mlp.linear_fc2.bias" + def norm1_weight(self,i): + return f"model.visual.blocks.{i}.norm1.weight" + def norm1_bias(self,i): + return f"model.visual.blocks.{i}.norm1.bias" + def norm2_weight(self,i): + return f"model.visual.blocks.{i}.norm2.weight" + def norm2_bias(self,i): + return f"model.visual.blocks.{i}.norm2.bias" + def deepstack_merger_linear_fc1_weight(self,i): + return f"model.visual.deepstack_merger_list.{i}.linear_fc1.weight" + def deepstack_merger_linear_fc1_bias(self,i): + return f"model.visual.deepstack_merger_list.{i}.linear_fc1.bias" + def deepstack_merger_linear_fc2_weight(self,i): + return f"model.visual.deepstack_merger_list.{i}.linear_fc2.weight" + def deepstack_merger_linear_fc2_bias(self,i): + return f"model.visual.deepstack_merger_list.{i}.linear_fc2.bias" + def deepstack_merger_norm_weight(self,i): + return f"model.visual.deepstack_merger_list.{i}.norm.weight" + def deepstack_merger_norm_bias(self,i): + return f"model.visual.deepstack_merger_list.{i}.norm.bias" + + def merger_linear_fc1_weight(self): + return "model.visual.merger.linear_fc1.weight" + def 
merger_linear_fc1_bias(self): + return "model.visual.merger.linear_fc1.bias" + def merger_linear_fc2_weight(self): + return "model.visual.merger.linear_fc2.weight" + def merger_linear_fc2_bias(self): + return "model.visual.merger.linear_fc2.bias" + def merger_norm_weight(self): + return "model.visual.merger.norm.weight" + def merger_norm_bias(self): + return "model.visual.merger.norm.bias" + +class Qwen3vlMeta(Qwen3vlMetaCStruct): + def __init__(self, config, max_tokens=None): + + if config['text_config']['dtype'] == 'float16': + dt_ = DataType.INFINI_DTYPE_F16 + self.torch_dtype = torch.float16 + elif config['text_config']['dtype'] == 'float32': + dt_ = DataType.INFINI_DTYPE_F32 + self.torch_dtype = torch.float32 + elif config['text_config']['dtype'] == 'bfloat16': + dt_ = DataType.INFINI_DTYPE_BF16 + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"Unsupported text dtype: {config['text_config']['dtype']}") + + super().__init__( + dtype = dt_, + image_token_id = config['image_token_id'], + video_token_id = config['video_token_id'], + vision_end_token_id = config['vision_end_token_id'], + vision_start_token_id = config['vision_start_token_id'], + text_meta = TextMetaCStruct( + bos_token_id = config['text_config']['bos_token_id'], + eos_token_id = config['text_config']['eos_token_id'], + head_dim = config['text_config']['head_dim'], + hidden_size = config['text_config']['hidden_size'], + initializer_range = config['text_config']['initializer_range'], + intermediate_size = config['text_config']['intermediate_size'], + max_tokens = (config['text_config']['max_position_embeddings'] if max_tokens is None else max_tokens), + num_attention_heads = config['text_config']['num_attention_heads'], + num_hidden_layers = config['text_config']['num_hidden_layers'], + num_key_value_heads = config['text_config']['num_key_value_heads'], + rms_norm_eps = config['text_config']['rms_norm_eps'], + mrope_section = (ctypes.c_ulong * 
3)(*config['text_config']['rope_scaling']['mrope_section']), + rope_theta = config['text_config']['rope_theta'], + vocab_size = config['text_config']['vocab_size'], + ), + vis_meta = VisMetaCStruct( + depth = config['vision_config']['depth'], + deepstack_visual_indexes = (ctypes.c_ulong * 3)(*config['vision_config']['deepstack_visual_indexes']), + hidden_size = config['vision_config']['hidden_size'], + in_channels = config['vision_config']['in_channels'], + initializer_range = config['vision_config']['initializer_range'], + intermediate_size = config['vision_config']['intermediate_size'], + num_heads = config['vision_config']['num_heads'], + num_position_embeddings = config['vision_config']['num_position_embeddings'], + out_hidden_size = config['vision_config']['out_hidden_size'], + patch_size = config['vision_config']['patch_size'], + spatial_merge_size = config['vision_config']['spatial_merge_size'], + temporal_patch_size = config['vision_config']['temporal_patch_size'] + ) + ) + +def load_specific_tensor(model_dir, tensor_name): + """ + Load a specific tensor from a safetensors model. + Supports both sharded models (with index.json) and single file models. 
+ """ + + # Try to load from individual .safetensors files + safetensors_files = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")] + if not safetensors_files: + raise FileNotFoundError(f"No .safetensors files found in {model_dir}") + + # Try to find the tensor in each file + for filename in safetensors_files: + tensor_file = os.path.join(model_dir, filename) + try: + with safetensors.safe_open(tensor_file, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor = f.get_tensor(tensor_name) + return tensor + except Exception: + continue + + # If we reach here, tensor was not found in any file + raise KeyError(f"{tensor_name} not found in any .safetensors files") + +def load_Qwen3vl_weights( + meta: Qwen3vlMeta, + weights, + model_path: str, + ndev: int, +): + # torch load weights, and reshape for qkv_proj / mlp_gate_up stack, attn / mlp parallel + # weight loader function load from specific offset according to idev, and transpose + model_instance = Qwen3vlModel() + weight_loader = model_instance.create_weight_loader() + vis_names = Qwen3vlVisWeightsNaming() + lang_names = Qwen3vlLangWeightsNaming() + + nkvh = meta.text_meta.num_key_value_heads + nh = meta.text_meta.num_attention_heads + dh = meta.text_meta.head_dim + d = meta.text_meta.hidden_size + di = meta.text_meta.intermediate_size + + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + + # ------------------------------- + # Language_model weights + # ------------------------------- + input_embd = load_specific_tensor(model_path, lang_names.input_embd()).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_input_embd(weights, input_embd.data_ptr()) + del input_embd + + output_norm = load_specific_tensor(model_path, lang_names.output_norm()).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_output_norm(weights, output_norm.data_ptr()) + del output_norm + + output_embd = load_specific_tensor(model_path, 
lang_names.output_embd()).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_output_embd(weights, output_embd.data_ptr()) + del output_embd + + for i in range(meta.text_meta.num_hidden_layers): + attn_norm = load_specific_tensor(model_path, lang_names.attn_norm(i)).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_attn_norm(weights, attn_norm.data_ptr(), i) + del attn_norm + + attn_q_proj = load_specific_tensor(model_path, lang_names.attn_q_proj(i)) + attn_k_proj = load_specific_tensor(model_path, lang_names.attn_k_proj(i)) + attn_v_proj = load_specific_tensor(model_path, lang_names.attn_v_proj(i)) + + _Q = attn_q_proj.reshape(nh,dh,d) + _K = attn_k_proj.reshape(nkvh,dh,d) + _V = attn_v_proj.reshape(nkvh,dh,d) + + qkv_proj = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + qkv_proj.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :]) + qkv_proj.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + qkv_proj.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + attn_qkv_proj = torch.cat(qkv_proj, dim=0).to(meta.torch_dtype).contiguous() + + weight_loader.contents.lang_loader.load_attn_qkv_proj(weights, attn_qkv_proj.data_ptr(), i) + del attn_qkv_proj + + attn_q_norm = load_specific_tensor(model_path, lang_names.attn_q_norm(i)).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_attn_q_norm(weights, attn_q_norm.data_ptr(), i) + del attn_q_norm + + attn_k_norm = load_specific_tensor(model_path, lang_names.attn_k_norm(i)).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_attn_k_norm(weights, attn_k_norm.data_ptr(), i) + del attn_k_norm + + attn_o_proj = load_specific_tensor(model_path, lang_names.attn_o_proj(i)) + attn_o_proj = attn_o_proj.to(meta.torch_dtype).reshape([d, ndev, nh // ndev * dh]).transpose(0, 1).contiguous() + weight_loader.contents.lang_loader.load_attn_o_proj(weights, attn_o_proj.data_ptr(), i) + del attn_o_proj + + mlp_norm = load_specific_tensor(model_path, 
lang_names.mlp_norm(i)).to(meta.torch_dtype) + weight_loader.contents.lang_loader.load_mlp_norm(weights, mlp_norm.data_ptr(), i) + del mlp_norm + + mlp_gate = load_specific_tensor(model_path, lang_names.mlp_gate(i)) + mlp_up = load_specific_tensor(model_path, lang_names.mlp_up(i)) + + gate_up = [] + _di = di // ndev + for _idev in range(ndev): + _start = _idev * _di + _end = (_idev + 1) * _di + gate_up.append(mlp_gate[_start:_end, :]) + gate_up.append(mlp_up[_start:_end, :]) + mlp_gate_up = torch.cat(gate_up, dim=0).to(meta.torch_dtype).contiguous() + + weight_loader.contents.lang_loader.load_mlp_gate_up(weights, mlp_gate_up.data_ptr(), i) + del mlp_gate_up + + mlp_down = load_specific_tensor(model_path, lang_names.mlp_down(i)) + mlp_down = mlp_down.to(meta.torch_dtype).reshape([d, ndev, di // ndev]).transpose(0, 1).contiguous() + weight_loader.contents.lang_loader.load_mlp_down(weights, mlp_down.data_ptr(), i) + del mlp_down + + # ------------------------------- + # Vision head weights + # ------------------------------- + patch_embed_weight = load_specific_tensor(model_path, vis_names.patch_embed_weight()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_patch_embed_weight(weights, patch_embed_weight.data_ptr()) + del patch_embed_weight + + patch_embed_bias = load_specific_tensor(model_path, vis_names.patch_embed_bias()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_patch_embed_bias(weights, patch_embed_bias.data_ptr()) + del patch_embed_bias + + pos_embed_weight = load_specific_tensor(model_path, vis_names.pos_embed_weight()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_pos_embed_weight(weights, pos_embed_weight.data_ptr()) + del pos_embed_weight + + for i in range(meta.vis_meta.depth): + attn_proj_weight = load_specific_tensor(model_path, vis_names.attn_proj_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_attn_proj_weight(weights, attn_proj_weight.data_ptr(), i) + del attn_proj_weight + + 
attn_proj_bias = load_specific_tensor(model_path, vis_names.attn_proj_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_attn_proj_bias(weights, attn_proj_bias.data_ptr(), i) + del attn_proj_bias + + attn_qkv_weight = load_specific_tensor(model_path, vis_names.attn_qkv_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_attn_qkv_weight(weights, attn_qkv_weight.data_ptr(), i) + del attn_qkv_weight + + attn_qkv_bias = load_specific_tensor(model_path, vis_names.attn_qkv_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_attn_qkv_bias(weights, attn_qkv_bias.data_ptr(), i) + del attn_qkv_bias + + mlp_linear_fc1_weight = load_specific_tensor(model_path, vis_names.mlp_linear_fc1_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_mlp_linear_fc1_weight(weights, mlp_linear_fc1_weight.data_ptr(), i) + del mlp_linear_fc1_weight + + mlp_linear_fc1_bias = load_specific_tensor(model_path, vis_names.mlp_linear_fc1_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_mlp_linear_fc1_bias(weights, mlp_linear_fc1_bias.data_ptr(), i) + del mlp_linear_fc1_bias + + mlp_linear_fc2_weight = load_specific_tensor(model_path, vis_names.mlp_linear_fc2_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_mlp_linear_fc2_weight(weights, mlp_linear_fc2_weight.data_ptr(), i) + del mlp_linear_fc2_weight + + mlp_linear_fc2_bias = load_specific_tensor(model_path, vis_names.mlp_linear_fc2_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_mlp_linear_fc2_bias(weights, mlp_linear_fc2_bias.data_ptr(), i) + del mlp_linear_fc2_bias + + norm1_weight = load_specific_tensor(model_path, vis_names.norm1_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_norm1_weight(weights, norm1_weight.data_ptr(), i) + del norm1_weight + + norm1_bias = load_specific_tensor(model_path, vis_names.norm1_bias(i)).to(meta.torch_dtype) + 
weight_loader.contents.vis_loader.load_norm1_bias(weights, norm1_bias.data_ptr(), i) + del norm1_bias + + norm2_weight = load_specific_tensor(model_path, vis_names.norm2_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_norm2_weight(weights, norm2_weight.data_ptr(), i) + del norm2_weight + + norm2_bias = load_specific_tensor(model_path, vis_names.norm2_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_norm2_bias(weights, norm2_bias.data_ptr(), i) + del norm2_bias + + for i in range(len(meta.vis_meta.deepstack_visual_indexes)): + deepstack_merger_linear_fc1_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc1_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_weight(weights, deepstack_merger_linear_fc1_weight.data_ptr(), i) + del deepstack_merger_linear_fc1_weight + + deepstack_merger_linear_fc1_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc1_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_bias(weights, deepstack_merger_linear_fc1_bias.data_ptr(), i) + del deepstack_merger_linear_fc1_bias + + deepstack_merger_linear_fc2_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc2_weight(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_weight(weights, deepstack_merger_linear_fc2_weight.data_ptr(), i) + del deepstack_merger_linear_fc2_weight + + deepstack_merger_linear_fc2_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc2_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_bias(weights, deepstack_merger_linear_fc2_bias.data_ptr(), i) + del deepstack_merger_linear_fc2_bias + + deepstack_merger_norm_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_norm_weight(i)).to(meta.torch_dtype) + 
weight_loader.contents.vis_loader.load_deepstack_merger_norm_weight(weights, deepstack_merger_norm_weight.data_ptr(), i) + del deepstack_merger_norm_weight + + deepstack_merger_norm_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_norm_bias(i)).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_deepstack_merger_norm_bias(weights, deepstack_merger_norm_bias.data_ptr(), i) + del deepstack_merger_norm_bias + + merger_linear_fc1_weight = load_specific_tensor(model_path, vis_names.merger_linear_fc1_weight()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_linear_fc1_weight(weights, merger_linear_fc1_weight.data_ptr()) + del merger_linear_fc1_weight + + merger_linear_fc1_bias = load_specific_tensor(model_path, vis_names.merger_linear_fc1_bias()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_linear_fc1_bias(weights, merger_linear_fc1_bias.data_ptr()) + del merger_linear_fc1_bias + + merger_linear_fc2_weight = load_specific_tensor(model_path, vis_names.merger_linear_fc2_weight()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_linear_fc2_weight(weights, merger_linear_fc2_weight.data_ptr()) + del merger_linear_fc2_weight + + merger_linear_fc2_bias = load_specific_tensor(model_path, vis_names.merger_linear_fc2_bias()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_linear_fc2_bias(weights, merger_linear_fc2_bias.data_ptr()) + del merger_linear_fc2_bias + + merger_norm_weight = load_specific_tensor(model_path, vis_names.merger_norm_weight()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_norm_weight(weights, merger_norm_weight.data_ptr()) + del merger_norm_weight + + merger_norm_bias = load_specific_tensor(model_path, vis_names.merger_norm_bias()).to(meta.torch_dtype) + weight_loader.contents.vis_loader.load_merger_norm_bias(weights, merger_norm_bias.data_ptr()) + del merger_norm_bias + + +class Qwen3vlBatchedTask: + def __init__(self, 
tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(Qwen3vlCacheCStruct) * self.nreq)( + *self.kv_cache_ptrs + ) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + +# 需要处理 visual encoder的cache 和 image video输入 +class Qwen3vlForCauslLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["text_config"]["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + print(model_dir_path) + + if "qwen3_vl" == config["model_type"]: + self.meta = Qwen3vlMeta( + config, max_tokens=max_tokens + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + else: + raise ValueError("Unsupported model architecture") + + print(f"Creating model on {ndev} devices...") + 
load_start_time = time.time() + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + + self.model_instance = Qwen3vlModel() + weights = self.model_instance.create_weights( + byref(self.meta), + device, + ndev, + dev_ids, + c_bool(True) + ) + print("Loading weights...") + # Load weights from host + load_Qwen3vl_weights(self.meta, weights, model_dir_path, ndev) + # Create model instance + self.model_ptr = self.model_instance.create_model( + byref(self.meta), + weights, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def max_context_len(self): + return self.meta.text_meta.max_tokens + + def create_kv_cache(self): + return self.model_instance.create_cache(self.model_ptr) + + def drop_kv_cache(self, kv_cache): + self.model_instance.drop_cache(self.model_ptr, kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = Qwen3vlBatchedTask(tasks) + self.model_instance.infer_batch( + self.model_ptr, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + + tokens = self.tokenizer.encode(input_content) + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + infer_task.bind_kvcache(KVCache(self)) + print(input_content, end="", flush=True) + steps = 0 + total_time = 0 + output_content = "" + + print(tokens) + + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + print(output_tokens) + end_time = time.time() + steps += 1 + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += output_str + print(output_str, end="", flush=True) + if 
output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / steps if steps > 0 else -1 + print(f"Time per step: {avg_time:.3f}ms") + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + def destroy_model_instance(self): + self.model_instance.destroy_model(self.model_ptr) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python qwen3vl.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + print( + "Usage: python qwen3vl.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = Qwen3vlForCauslLM(model_path, device_type, ndev, max_tokens=1024) + model.generate("山东最高的山是?", 50) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/scripts/qwen3vl_test.py b/scripts/qwen3vl_test.py new file mode 100644 index 00000000..893cad6e --- /dev/null +++ b/scripts/qwen3vl_test.py @@ -0,0 +1,102 @@ +import torch +from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, GenerationConfig +import os +import time + +# 加载模型和processor +# 
修改为使用Qwen3VLForConditionalGeneration和AutoProcessor +model = Qwen3VLForConditionalGeneration.from_pretrained( + "/home/user/workshop/Qwen3-VL-2B-Instruct/", + torch_dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa", + trust_remote_code=True +) +processor = AutoProcessor.from_pretrained("/home/user/workshop/Qwen3-VL-2B-Instruct/", trust_remote_code=True) + +# 设置生成配置以确保确定性生成 +model.generation_config = GenerationConfig.from_pretrained("/home/user/workshop/Qwen3-VL-2B-Instruct/", trust_remote_code=True) +model.generation_config.do_sample = False # 关闭采样以确保确定性 +model.generation_config.max_new_tokens = 50 + +# 输入消息 - 结合文本和图像(这里仅保留文本示例) +messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "山东最高的山是?" + } + ] + } +] + +# 处理输入 +inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", +) +inputs = {k: v.to(model.device) for k, v in inputs.items()} +inputs.pop("token_type_ids", None) + +print("Input token IDs:", inputs["input_ids"][0].tolist()) +print("Input text:", processor.decode(inputs["input_ids"][0])) + +# 获取输入信息用于逐token生成 +input_ids = inputs["input_ids"] +attention_mask = inputs["attention_mask"] + +# 记录开始生成时的总token数 +initial_length = input_ids.shape[1] +generated_tokens = [] +generation_times = [] + +# 逐token生成 +with torch.no_grad(): + current_input_ids = input_ids + current_attention_mask = attention_mask + + # 获取EOS token ID + eos_token_id = model.generation_config.eos_token_id + + for i in range(model.generation_config.max_new_tokens): + start_time = time.time() + + # 单步生成 + outputs = model( + input_ids=current_input_ids, + attention_mask=current_attention_mask, + ) + + # 获取下一个token + next_token_logits = outputs.logits[:, -1, :] + next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + + # 检查是否达到结束条件 + if next_token_id.item() == eos_token_id: + break + + # 记录生成时间 + end_time = time.time() + 
generation_times.append((end_time - start_time) * 1000) # 转换为毫秒 + + # 添加到已生成的token中 + generated_tokens.append(next_token_id.item()) + + # 更新输入以包含新生成的token + current_input_ids = torch.cat([current_input_ids, next_token_id], dim=1) + current_attention_mask = torch.cat([current_attention_mask, torch.ones((current_attention_mask.shape[0], 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device)], dim=1) + +# 计算平均生成时间 +if generation_times: + avg_generation_time = sum(generation_times) / len(generation_times) + print(f"生成的tokens: {generated_tokens}") + print(f"生成的文本: {processor.decode(generated_tokens, skip_special_tokens=True)}") + print(f"生成的token数量: {len(generated_tokens)}") + print(f"平均生成一个token的时间: {avg_generation_time:.3f} ms") +else: + print("未生成任何新token") \ No newline at end of file diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 4c49e961..a8d1cd99 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -153,6 +153,7 @@ class LRUDescriptorCache { class CacheManager { public: DECLARE_OP_CACHE(Add) + DECLARE_OP_CACHE(Mul) DECLARE_OP_CACHE(RMSNorm) DECLARE_OP_CACHE(Gemm) DECLARE_OP_CACHE(RoPE) @@ -160,11 +161,13 @@ class CacheManager { DECLARE_OP_CACHE(CausalSoftmax) DECLARE_OP_CACHE(Topkrouter) DECLARE_OP_CACHE(SwiGLU) + DECLARE_OP_CACHE(Silu) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(DequantizeAWQ) CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), + Mul_cache(capacity, DESTROY_FUNC(Mul)), RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)), Gemm_cache(capacity, DESTROY_FUNC(Gemm)), RoPE_cache(capacity, DESTROY_FUNC(RoPE)), @@ -172,6 +175,7 @@ class CacheManager { CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)), Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), + Silu_cache(capacity, DESTROY_FUNC(Silu)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), 
DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)) {} diff --git a/src/models/deepseek_v3/deepseek_v3.cpp b/src/models/deepseek_v3/deepseek_v3.cpp index 2c463035..c60ef9d7 100644 --- a/src/models/deepseek_v3/deepseek_v3.cpp +++ b/src/models/deepseek_v3/deepseek_v3.cpp @@ -103,8 +103,8 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc auto batch_pos_ids = std::vector(ntok); size_t req_start = 0; for (uint32_t req = 0; req < nreq; req++) { - for (uint32_t i = 0; i < req_lens[req]; i++) { - batch_pos_ids[req_start + i] = req_pos[req] + i; + for (uint32_t i = 0; i < req_lens[req]; i++) { // req_len 本次query长度,req_pos 历史长度 + batch_pos_ids[req_start + i] = req_pos[req] + i; //batch_pos_ids 展平后每个token的pos } req_start += req_lens[req]; } diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index db5fda11..edf3fd96 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -33,6 +33,27 @@ void InferenceContext::add(std::shared_ptr c, c->data(), a->data(), b->data(), stream)); } +void InferenceContext::mul(std::shared_ptr c, + std::shared_ptr a, + std::shared_ptr b) { + size_t key = CacheManager::createDescriptorKey(c, a, b); + + infiniopMulDescriptor_t desc; + if (!cache_manager->getMulDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateMulDescriptor(op_handle, &desc, c->desc(), a->desc(), b->desc())); + cache_manager->putMulDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetMulWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopMul( + desc, workspace, workspace_size, + c->data(), a->data(), b->data(), stream)); +} + void InferenceContext::rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, @@ -189,6 +210,26 @@ void InferenceContext::swiglu(std::shared_ptr out, out->data(), up->data(), gate->data(), stream)); } +void 
InferenceContext::silu(std::shared_ptr out, + std::shared_ptr input) { + size_t key = CacheManager::createDescriptorKey(out, input); + + infiniopSiluDescriptor_t desc; + if (!cache_manager->getSiluDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateSiluDescriptor( + op_handle, &desc, out->desc(), input->desc())); + cache_manager->putSiluDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetSiluWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopSilu(desc, workspace, workspace_size, + out->data(), input->data(), stream)); +} + void InferenceContext::randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature) { diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..76671777 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -19,6 +19,9 @@ struct InferenceContext { void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); + void mul(std::shared_ptr c, + std::shared_ptr a, + std::shared_ptr b); void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, @@ -48,6 +51,8 @@ struct InferenceContext { void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); + void silu(std::shared_ptr out, + std::shared_ptr input); void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature); @@ -81,6 +86,10 @@ inline void add(std::shared_ptr c, std::shared_ptr a, std::share getInferenceContext().add(c, a, b); } +inline void mul(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b) { + getInferenceContext().mul(c, a, b); +} + inline void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon) { getInferenceContext().rmsnorm(y, x, w, epsilon); @@ -131,6 +140,10 @@ inline void swiglu(std::shared_ptr out, 
std::shared_ptr up, getInferenceContext().swiglu(out, up, gate); } +inline void silu(std::shared_ptr out, std::shared_ptr input) { + getInferenceContext().silu(out, input); +} + inline void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature) { getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature); diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 41f8e5ea..b599d0d8 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -238,11 +238,13 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rearrange(q_rearrange->slice(2, 0, seq_len), q); auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + // [nkvh, ngroup * seq_len, dh] @ [nkvh, dh, total_len] = [nkvh, ngroup * seq_len, total_len] linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); // softmax auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); causalSoftmax(qk_softmax, qk_softmax); auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + // [nkvh, ngroup * seq_len, total_len] @ [nkvh, total_len, dh] = [nkvh, ngroup * seq_len, dh] linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); // rearrange attn val rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); diff --git a/src/models/qwen3vl/qwen3vl.cpp b/src/models/qwen3vl/qwen3vl.cpp new file mode 100644 index 00000000..d784b15e --- /dev/null +++ b/src/models/qwen3vl/qwen3vl.cpp @@ -0,0 +1,415 @@ +#include "qwen3vl_impl.hpp" + +#include "../../tensor.hpp" +#include "../../utils.hpp" +#include "../inference_context.hpp" +#include "infinicore_infer.h" + +#include +#include +#include + +void 
createDeviceResource(Qwen3vlDeviceResource *rsrc, const Qwen3vlMeta *meta, + std::shared_ptr weights, + infiniDevice_t device, int idev, + int ndev, int dev_id, + infinicclComm_t comm) { + RUN_INFINI(infinirtSetDevice(device, dev_id)); + RUN_INFINI(infinirtStreamSynchronize(weights->load_stream)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + auto memory_pool = std::make_shared(); + + *rsrc = Qwen3vlDeviceResource{ + device, + dev_id, + handle, + weights, + stream, + comm, + memory_pool, + }; + RUN_INFINI(infinirtDeviceSynchronize()); +} + +void releaseDeviceResource(Qwen3vlDeviceResource &res) { + infinirtDeviceSynchronize(); + + res.weights.reset(); + + infiniopDestroyHandle(res.handle); + res.handle = nullptr; + infinirtStreamDestroy(res.stream); + res.stream = nullptr; + infinicclCommDestroy(res.comm); + res.comm = nullptr; +} + +//todo: +// pd分离 +// flashattn + batching +// triron跨平台 +// pageattn + + +void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *last_logits) { + assert(meta.text_meta.num_attention_heads % ndev == 0); + assert(meta.text_meta.num_key_value_heads % ndev == 0); + + auto dtype = meta.dtype; + printf("meta dtype: %d \n",dtype); + auto nlayer = meta.text_meta.num_hidden_layers; + size_t nh = meta.text_meta.num_attention_heads / size_t(ndev); + size_t nkvh = meta.text_meta.num_key_value_heads / size_t(ndev); + auto ngroup = nh / nkvh; + auto dh = meta.text_meta.head_dim; + auto d = meta.text_meta.hidden_size; + auto di = meta.text_meta.intermediate_size / size_t(ndev); + auto dvoc = meta.text_meta.vocab_size; + float epsilon = meta.text_meta.rms_norm_eps; + auto stream = 
rsrc.stream; + auto weights = rsrc.weights; + + //Allocate buffers + auto logits_in = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool); + auto logits_out = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool); + + //所有请求的当前token + auto qkv_buf = Tensor::buffer(dtype, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); + auto o_buf = Tensor::buffer(dtype, {ntok, nh * dh}, rsrc.memory_pool); + auto gate_up_buf = Tensor::buffer(dtype, {ntok, 2*di}, rsrc.memory_pool); + + auto prob_buf = Tensor::buffer(dtype, {nreq, dvoc}, rsrc.memory_pool); + auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); + auto result_cpu = std::vector(nreq); + + //Prepare inputs + auto batch_pos_ids = std::vector(ntok); + size_t req_start = 0; + for (uint32_t req = 0; req < nreq; req++) { + for (uint32_t i = 0; i < req_lens[req]; i++) { // req_len 本次query长度,req_pos 历史长度 + batch_pos_ids[req_start + i] = req_pos[req] + i; //batch_pos_ids 展平后每个token的pos + } + req_start += req_lens[req]; + } + std::shared_ptr pos_ids_buf; + if (rsrc.device == INFINI_DEVICE_CPU) { + pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok}); + } else { + pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok, + INFINIRT_MEMCPY_H2D, stream)); + } + + //convert tokens to embeddings + for (uint32_t i = 0; i < ntok; i++) { + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + weights->w_lang->in_embd->data(tokens[i] * d), + dsize(dtype) * d, INFINIRT_MEMCPY_D2D, stream)); + } + + // attention inner + size_t max_qk_size = 0; + size_t max_seq_len = 0; + + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + + max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len)); + max_seq_len = std::max(max_seq_len, size_t(seq_len)); + } + + auto attn_score_buf = 
Tensor::buffer(dtype, {nh * max_qk_size}, rsrc.memory_pool); + auto attn_val_buf = Tensor::buffer(dtype, {nh, max_seq_len, dh}, rsrc.memory_pool); + auto rearrange_q_buf = Tensor::buffer(dtype, {nkvh, ngroup, max_seq_len, dh}, rsrc.memory_pool); + auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); + auto q_buf = qkv_rope->slice(1,0,nh); + auto k_buf = qkv_rope->slice(1,nh,nkvh); + + auto gate_buf = gate_up_buf->slice(1, 0, di); + auto up_buf = gate_up_buf->slice(1, di, di); + + //Compute + for (uint32_t i = 0; i < nlayer; i++){ + // attn norm + rmsnorm(logits_out,logits_in,weights->w_lang->layers[i].attn_norm,epsilon); + // qkv_proj + linear(qkv_buf,logits_out,weights->w_lang->layers[i].attn_qkv_proj,1.0,0.0,nullptr,nullptr); + // qk_norm + rmsnorm(q_buf,q_buf,weights->w_lang->layers[i].attn_q_norm,epsilon); + rmsnorm(k_buf,k_buf,weights->w_lang->layers[i].attn_k_norm,epsilon); + // rope + rope_v2(q_buf,q_buf,pos_ids_buf,weights->sin_table,weights->cos_table); + rope_v2(k_buf,k_buf,pos_ids_buf,weights->sin_table,weights->cos_table); + + // 逐个req处理 + size_t token_offset = 0; + for(uint32_t req=0; req < nreq; req++){ + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + + auto o = o_buf->slice(0,token_offset,seq_len)->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});// [nkvh, ngroup, seq_len, dh] + auto q = qkv_rope->slice({{0,token_offset,seq_len},{1,0,nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});// [nkvh, ngroup, seq_len, dh] + auto k = qkv_rope->slice({{0,token_offset,seq_len},{1,nh,nkvh}});// [ntok, nkvh, dh] + auto v = qkv_rope->slice({{0,token_offset,seq_len},{1,nh+nkvh,nkvh}});// [ntok, nkvh, dh] + + // concat to cache + rearrange(caches[req]->k_rot[idev][i]->slice(0,past_len,seq_len),k); + rearrange(caches[req]->v[idev][i]->slice(0,past_len,seq_len),v); + + //fill full_k full_v + auto full_k_buff = caches[req]->k_rot[idev][i]->slice(0,0,total_len)->permute({1,2,0});// [nkvh, 
dh, total_len] + auto full_v_buff = caches[req]->v[idev][i]->slice(0,0,total_len)->permute({1,0,2});// [nkvh, total_len, dh] + + //self-attn + auto attn_score_req = attn_score_buf->slice(0,0,nh*seq_len*total_len)->view({nkvh, ngroup*seq_len, total_len}); + auto rearrange_q = rearrange_q_buf->slice(2,0,seq_len); + rearrange(rearrange_q,q); + // [nkvh, ngroup * seq_len, dh] @ [nkvh, dh, total_len] = [nkvh, ngroup * seq_len, total_len] + linear(attn_score_req,rearrange_q->view({nkvh, ngroup * seq_len, dh}),full_k_buff,1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = attn_score_req->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax,qk_softmax); + auto attn_val_req = attn_val_buf->slice(1,0,seq_len)->view({nkvh, ngroup * seq_len, dh}); + // [nkvh, ngroup * seq_len, total_len] @ [nkvh, total_len, dh] = [nkvh, ngroup * seq_len, dh] + linear(attn_val_req, attn_score_req, full_v_buff, 1.0, 0.0, nullptr, nullptr); + rearrange(o,attn_val_req->view({nkvh, ngroup, seq_len, dh})); + token_offset += seq_len; + } + linear(logits_in, o_buf, weights->w_lang->layers[i].attn_o_proj, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); + // All_reduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + logits_in->data(), logits_in->data(), ntok * d, dtype, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + + // mlp norm + rmsnorm(logits_out,logits_in,weights->w_lang->layers[i].mlp_norm,epsilon); + // mlp gate_up + linear(gate_up_buf,logits_out,weights->w_lang->layers[i].mlp_gate_up,1.0,0.0,nullptr,nullptr); + // silu + silu(gate_buf,gate_buf); + mul(gate_buf,gate_buf,up_buf); + // mlp down + linear(logits_in,gate_buf,weights->w_lang->layers[i].mlp_down,1.0, 0.0, idev == 0 ? 
logits_in : nullptr, nullptr); + // All_reduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + logits_in->data(), logits_in->data(), ntok * d, dtype, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + } + // sample and output + if (idev == 0) { + if (last_logits != nullptr) { + rmsnorm(logits_out, logits_in, weights->w_lang->out_norm, meta.text_meta.rms_norm_eps); + auto last_logits_buf = Tensor::buffer(dtype, {ntok, dvoc}, rsrc.memory_pool); + linear(last_logits_buf, logits_out, weights->w_lang->out_embd, 1.0, 0.0, nullptr, nullptr); + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dtype) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); + } + if (output != nullptr) { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + token_offset += seq_len; + rmsnorm(logits_out->slice(0, req, 1), + logits_in->slice(0, token_offset - 1, 1), + weights->w_lang->out_norm, + meta.text_meta.rms_norm_eps); + } + //logits_out->slice(0,0,nreq)->debug(); different from transformers + linear(prob_buf, logits_out->slice(0, 0, nreq), weights->w_lang->out_embd, 1.0, 0.0, nullptr, nullptr); + std::random_device _rd; + std::mt19937 gen(_rd()); + token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + float random_val = std::uniform_real_distribution(0, 1)(gen); + randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), + prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), + random_val, topp[req], topk[req], temperature[req]); + token_offset += seq_len; + } + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(), + sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H)); + for (uint32_t req = 0; req < nreq; req++) { + output[req] = uint32_t(result_cpu[req]); + } + } + } +} + +__C void +inferBatchQwen3vl(struct 
Qwen3vlModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +forwardBatchQwen3vl(struct Qwen3vlModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **kv_caches, + void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + 
model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +void launchDevice(const Qwen3vlMeta &meta, std::shared_ptr weights, Qwen3vlDeviceResource *rsrc, InferState &state, InferRequest &req, + infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { + // Create Device Resource + createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); + + CacheManager cache_manager(100); + InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); + + // Set the inference context for this thread + setInferenceContext(&ctx); + + { + std::unique_lock lock(state.mtx); + state.loaded = true; + lock.unlock(); + state.cv_load.notify_one(); + } + + // Infer Loop + while (true) { + std::unique_lock lock(state.mtx); + state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); + // quit if exit_flag is set + if (state.exit_flag) { + break; + } + + inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, + req.req_lens, req.nreq, req.req_pos, req.kv_caches, + req.temperature, req.topk, req.topp, req.output, req.logits); + + state.proceed = false; + lock.unlock(); + state.cv_done.notify_one(); + } + + // Clean-Up + releaseDeviceResource(*rsrc); + setInferenceContext(nullptr); // Clear the context when done +} + + +Qwen3vlModel::Qwen3vlModel(const Qwen3vlMeta *_meta, const Qwen3vlWeights *weights) : meta(*_meta) { + auto device_weights = weights->device_weights; + int ndev = device_weights.size(); + device = device_weights[0]->device; + dev_ids.resize(ndev); + for (int i = 0; i < ndev; i++) { + dev_ids[i] = device_weights[i]->dev_id; + } + dev_resources = std::vector(ndev); + states = std::vector(ndev); + threads.resize(ndev); + RUN_INFINI(infinirtInit()); + auto comms = std::vector(ndev, nullptr); + if (ndev > 1) { + RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); + } + for (int i = 0; i < ndev; i++) { + 
threads[i] = std::thread(launchDevice, std::cref(meta), device_weights[i], &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); + } + for (int i = 0; i < ndev; i++) { + std::unique_lock lock(states[i].mtx); + states[i].cv_load.wait(lock, [&] { return states[i].loaded; }); + lock.unlock(); + } +} + +__C struct Qwen3vlModel * +createQwen3vlModel(const Qwen3vlMeta *_meta, + const Qwen3vlWeights *weights) { + Qwen3vlModel *model = new Qwen3vlModel(_meta, weights); + return model; +} + +__C void +destroyQwen3vlModel(struct Qwen3vlModel *model) { + auto ndev = model->dev_resources.size(); + + for (size_t idev = 0; idev < ndev; idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].exit_flag = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + + for (size_t idev = 0; idev < ndev; idev++) { + model->threads[idev].join(); + } + + delete model; +} diff --git a/src/models/qwen3vl/qwen3vl_cache.cpp b/src/models/qwen3vl/qwen3vl_cache.cpp new file mode 100644 index 00000000..b10b86c8 --- /dev/null +++ b/src/models/qwen3vl/qwen3vl_cache.cpp @@ -0,0 +1,43 @@ +#include "qwen3vl_impl.hpp" + +__C struct Qwen3vlCache * +createQwen3vlCache(const struct Qwen3vlModel *model) { + Qwen3vlCache *cache = new Qwen3vlCache(); + auto ndev = model->dev_resources.size(); + auto nlayer = model->meta.text_meta.num_hidden_layers; + auto max_len = model->meta.text_meta.max_tokens; + auto dh = model->meta.text_meta.head_dim; + auto nkv = model->meta.text_meta.num_key_value_heads / size_t(ndev); + auto k_rot_shape = std::vector{max_len, nkv, dh}; + auto v_shape = std::vector{max_len, nkv, dh}; + for (size_t idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + auto k_rot_cache = std::vector>(); + auto v_cache = std::vector>(); + for (size_t layer = 0; layer < nlayer; layer++) { + k_rot_cache.push_back(std::move(Tensor::buffer(model->meta.dtype, k_rot_shape))); + 
v_cache.push_back(std::move(Tensor::buffer(model->meta.dtype, v_shape))); + } + cache->k_rot.push_back(k_rot_cache); + cache->v.push_back(v_cache); + } + + return cache; +} + +////// TODO(review): do the visual deepstack features also need a cache? + +__C void +dropQwen3vlCache(const struct Qwen3vlModel *model, + struct Qwen3vlCache *cache) { + auto ndev = model->dev_resources.size(); + auto nlayer = model->meta.text_meta.num_hidden_layers; + for (size_t idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + for (size_t layer = 0; layer < nlayer; layer++) { + cache->k_rot[idev][layer].reset(); + cache->v[idev][layer].reset(); + } + } + delete cache; +} \ No newline at end of file diff --git a/src/models/qwen3vl/qwen3vl_impl.hpp b/src/models/qwen3vl/qwen3vl_impl.hpp new file mode 100644 index 00000000..7e126f6c --- /dev/null +++ b/src/models/qwen3vl/qwen3vl_impl.hpp @@ -0,0 +1,130 @@ +#ifndef QWEN3VL_IMPL_H +#define QWEN3VL_IMPL_H + +#include "infinicore_infer.h" + +#include "../../allocator.hpp" +#include "../../tensor.hpp" + +#include +#include +#include +#include +#include + +struct LayerWeight { + std::shared_ptr attn_norm; + std::shared_ptr attn_qkv_proj; + std::shared_ptr attn_q_norm; + std::shared_ptr attn_k_norm; + std::shared_ptr attn_o_proj; + + std::shared_ptr mlp_norm; + std::shared_ptr mlp_down, mlp_gate_up; +}; + +struct LanguageModelWeight { + std::shared_ptr in_embd, out_embd, out_norm; + std::vector layers; +}; + +struct VisBlockWeight { + std::shared_ptr attn_proj_weight, attn_proj_bias, attn_qkv_weight, attn_qkv_bias; + std::shared_ptr mlp_linear_fc1_weight, mlp_linear_fc1_bias, mlp_linear_fc2_weight, mlp_linear_fc2_bias; + std::shared_ptr norm1_weight, norm1_bias, norm2_weight, norm2_bias; +}; + +struct DeepstackMergerWeight { + std::shared_ptr linear_fc1_weight, linear_fc1_bias, linear_fc2_weight, linear_fc2_bias; + std::shared_ptr norm_weight, norm_bias; +}; + +struct MergerWeight { + std::shared_ptr linear_fc1_weight, 
linear_fc1_bias, linear_fc2_weight, linear_fc2_bias; + std::shared_ptr norm_weight, norm_bias; +}; + + +struct VisualEncoderWeight { + std::shared_ptr patch_embed_weight, patch_embed_bias, pos_embed_weight; + std::vector blocks; + std::vector deepstack_mergers; + std::shared_ptr merger; +}; + + +struct Qwen3vlDeviceWeights { + std::shared_ptr sin_table,cos_table; + std::shared_ptr w_lang; + std::shared_ptr w_vis; + infiniDevice_t device; + int dev_id; + infinirtStream_t load_stream; +}; + +struct Qwen3vlWeights { + Qwen3vlMeta const *meta; + bool transpose_weight; + std::vector> device_weights; + + Qwen3vlWeights(const Qwen3vlMeta *meta, + infiniDevice_t device, + int ndev, + const int *dev_ids, + bool transpose_weight); +}; + +struct Qwen3vlDeviceResource { + // Device + infiniDevice_t device; + int device_id; + infiniopHandle_t handle; + // Weights + std::shared_ptr weights; + // Streams + infinirtStream_t stream; + // Communicator + infinicclComm_t comm; + + std::shared_ptr memory_pool; +}; + +struct InferState { + std::mutex mtx; + std::condition_variable cv_load, cv_start, cv_done; + bool loaded = false; + bool proceed = false; + bool exit_flag = false; +}; + +struct InferRequest { + const uint32_t *tokens; + uint32_t ntok; + const uint32_t *req_lens; + uint32_t nreq; + const uint32_t *req_pos; + struct Qwen3vlCache **kv_caches; + const float *temperature; + const uint32_t *topk; + const float *topp; + uint32_t *output; + void *logits; +}; + +struct Qwen3vlModel { + Qwen3vlMeta meta; + infiniDevice_t device; + std::vector dev_ids; + std::vector dev_resources; + std::vector states; + std::vector threads; + InferRequest req; + + Qwen3vlModel(const Qwen3vlMeta *, const Qwen3vlWeights *weights); +}; + +struct Qwen3vlCache { + std::vector>> k_rot, v; +}; + +#endif \ No newline at end of file diff --git a/src/models/qwen3vl/qwen3vl_weight.cpp b/src/models/qwen3vl/qwen3vl_weight.cpp new file mode 100644 index 00000000..f525b30e --- /dev/null +++ 
b/src/models/qwen3vl/qwen3vl_weight.cpp @@ -0,0 +1,634 @@ +#include "qwen3vl_impl.hpp" + +#include + +inline std::shared_ptr getInEmbd( + const Qwen3vlMeta *meta) { + auto shape = std::vector({meta->text_meta.vocab_size, meta->text_meta.hidden_size}); + return Tensor::weight(nullptr, meta->dtype, shape); +} + +inline std::shared_ptr getOutNorm( + const Qwen3vlMeta *meta) { + auto shape = std::vector({meta->text_meta.hidden_size}); + return Tensor::weight(nullptr, meta->dtype, shape); +} + +inline std::shared_ptr getOutEmbd( + const Qwen3vlMeta *meta) { + + auto shape = std::vector({meta->text_meta.vocab_size, meta->text_meta.hidden_size}); + return Tensor::weight(nullptr, meta->dtype, shape) + ->permute({1, 0}); +} + +inline void getLayerWeight( + const Qwen3vlMeta *meta,LayerWeight& layer, int ndev) { + auto nkvh = meta->text_meta.num_key_value_heads; + auto nh = meta->text_meta.num_attention_heads; + auto dh = meta->text_meta.head_dim; + auto d = meta->text_meta.hidden_size; + auto di = meta->text_meta.intermediate_size; + + auto dh_shape = std::vector({meta->text_meta.hidden_size}); + layer.attn_norm = Tensor::weight(nullptr, meta->dtype, dh_shape); + auto qk_norm_shape = std::vector({meta->text_meta.head_dim}); + layer.attn_q_norm = Tensor::weight(nullptr, meta->dtype, qk_norm_shape); + layer.attn_k_norm = Tensor::weight(nullptr, meta->dtype, qk_norm_shape); + auto qkv_proj_shape = std::vector({(nh + 2 * nkvh) / ndev * dh, d}); + layer.attn_qkv_proj = Tensor::weight(nullptr, meta->dtype, qkv_proj_shape); + auto o_proj_shape = std::vector({d, nh / ndev * dh}); + layer.attn_o_proj = Tensor::weight(nullptr, meta->dtype, o_proj_shape); + + layer.mlp_norm = Tensor::weight(nullptr, meta->dtype, dh_shape); + auto up_shape = std::vector({2 * di / ndev, d}); + layer.mlp_gate_up = Tensor::weight(nullptr, meta->dtype, up_shape); + auto down_shape = std::vector({d, di / ndev}); + layer.mlp_down = Tensor::weight(nullptr, meta->dtype, down_shape); +} + + +inline void 
getVisualWeight( + const Qwen3vlMeta *meta, std::shared_ptr w_vis) { + Qwen3vlVisMeta vis_meta = meta->vis_meta; + auto patch_embed_shape = std::vector({vis_meta.hidden_size , vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size}); + w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape); + w_vis->patch_embed_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->pos_embed_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.num_position_embeddings, vis_meta.hidden_size}); + w_vis->merger = std::make_shared(); + w_vis->merger->linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.intermediate_size}); + w_vis->merger->linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size, vis_meta.intermediate_size}); + w_vis->merger->linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); + w_vis->merger->linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size}); + w_vis->merger->norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks = std::vector(vis_meta.depth); + for (size_t i = 0; i < vis_meta.depth; i++) { + w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size,vis_meta.hidden_size}); + w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks[i].attn_qkv_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size,vis_meta.hidden_size}); + w_vis->blocks[i].attn_qkv_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size}); + w_vis->blocks[i].mlp_linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.hidden_size}); + 
w_vis->blocks[i].mlp_linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); + w_vis->blocks[i].mlp_linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.intermediate_size}); + w_vis->blocks[i].mlp_linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks[i].norm1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks[i].norm1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks[i].norm2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + w_vis->blocks[i].norm2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); + } + w_vis->deepstack_mergers = std::vector(3); + for (size_t i = 0; i < 3; i++){ + w_vis->deepstack_mergers[i].linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size,vis_meta.intermediate_size}); + w_vis->deepstack_mergers[i].linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size,vis_meta.intermediate_size}); + w_vis->deepstack_mergers[i].linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); + w_vis->deepstack_mergers[i].linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size}); + w_vis->deepstack_mergers[i].norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); + w_vis->deepstack_mergers[i].norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); + } + +} + + +inline std::shared_ptr getSinTable(const Qwen3vlMeta *meta) { + auto half_dh = meta->text_meta.head_dim / 2; + auto unit = dsize(meta->dtype); + void *table = std::malloc(meta->text_meta.max_tokens * half_dh * unit); + + for (size_t i = 0; i < meta->text_meta.max_tokens; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _sin = std::sin( + static_cast(i) / std::pow(meta->text_meta.rope_theta, static_cast(j) / half_dh)); + if 
(meta->dtype == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); + } else if (meta->dtype == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin); + } else if (meta->dtype == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _sin; + } else { + std::cout << "unsupported data type" << std::endl; + exit(1); + } + } + } + auto shape = std::vector({meta->text_meta.max_tokens, half_dh}); + auto tensor = Tensor::weight(table, meta->dtype, shape); + std::free(table); + return tensor; +} + +inline std::shared_ptr getCosTable(const Qwen3vlMeta *meta) { + auto half_dh = meta->text_meta.head_dim / 2; + auto unit = dsize(meta->dtype); + void *table = std::malloc(meta->text_meta.max_tokens * half_dh * unit); + + for (size_t i = 0; i < meta->text_meta.max_tokens; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _cos = std::cos( + static_cast(i) / std::pow(meta->text_meta.rope_theta, static_cast(j) / half_dh)); + if (meta->dtype == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); + } else if (meta->dtype == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos); + } else if (meta->dtype == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _cos; + } else { + std::cout << "unsupported data type" << std::endl; + exit(1); + } + } + } + auto shape = std::vector({meta->text_meta.max_tokens, half_dh}); + auto tensor = Tensor::weight(table, meta->dtype, shape); + std::free(table); + return tensor; +} + +Qwen3vlWeights::Qwen3vlWeights( + const Qwen3vlMeta *_meta, infiniDevice_t device, int ndev, const int *dev_ids, bool _transpose_weight) { + meta = _meta; + transpose_weight = _transpose_weight; + device_weights = std::vector>(ndev); + for (int dev = 0; dev < ndev; dev++) { + int dev_id = dev_ids[dev]; + RUN_INFINI(infinirtSetDevice(device, dev_id)); + device_weights[dev] = std::make_shared(); + device_weights[dev]->device = device; + 
device_weights[dev]->dev_id = dev_id; + RUN_INFINI(infinirtStreamCreate(&device_weights[dev]->load_stream)); + + device_weights[dev]->w_lang = std::make_shared(); + device_weights[dev]->w_vis = std::make_shared(); + + device_weights[dev]->w_lang->in_embd = getInEmbd(meta); + device_weights[dev]->w_lang->out_norm = getOutNorm(meta); + device_weights[dev]->w_lang->out_embd = getOutEmbd(meta); + device_weights[dev]->sin_table = getSinTable(meta); + device_weights[dev]->cos_table = getCosTable(meta); + + device_weights[dev]->w_lang->layers = std::vector(meta->text_meta.num_hidden_layers); + + for (size_t layer = 0; layer < meta->text_meta.num_hidden_layers; layer++) { + getLayerWeight(meta, device_weights[dev]->w_lang->layers[layer], ndev); + } + + getVisualWeight(meta, device_weights[dev]->w_vis); + + } +} + +//--- Lang Global +void load_input_embd(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading input embedding from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->in_embd->load(cpu_ptr, weight->load_stream); + } +} + +void load_output_norm(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading output norm from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->out_norm->load(cpu_ptr, weight->load_stream); + } +} + +void load_output_embd(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading output embedding from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->out_embd->load(cpu_ptr, 
weight->load_stream); + if(weights->transpose_weight) { + weight->w_lang->out_embd->permute({1,0}); //[d,voc] NOTE(review): permute() result is discarded (no-op) — out_embd is already permuted in getOutEmbd; confirm intent + } + } +} + +// --- Attention +void load_attn_norm(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading attention norm " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].attn_norm->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_q_norm(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading attention q_norm " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].attn_q_norm->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_qkv_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading attention q_proj " << layer << " from " << cpu_ptr << std::endl; + int ndev = int(weights->device_weights.size()); + auto nkvh = weights->meta->text_meta.num_key_value_heads; + auto nh = weights->meta->text_meta.num_attention_heads; + auto dh = weights->meta->text_meta.head_dim; + auto d = weights->meta->text_meta.hidden_size; + //[ndev,nh+2*nkvh,dh,d] + for (int idev = 0; idev < ndev; idev++) { + auto weight = weights->device_weights[idev]; + size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(weights->meta->dtype); + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].attn_qkv_proj->load((char *)cpu_ptr + offset, weight->load_stream); + if(weights->transpose_weight) { + weight->w_lang->layers[layer].attn_qkv_proj = + weight->w_lang->layers[layer].attn_qkv_proj->permute({1,0}); //[d, (nh+2*nkvh)*dh] + } + } 
+} + +void load_attn_k_norm(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading attention k_norm " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].attn_k_norm->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_o_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading attention o_proj " << layer << " from " << cpu_ptr << std::endl; + int ndev = int(weights->device_weights.size()); + auto nh = weights->meta->text_meta.num_attention_heads; + auto dh = weights->meta->text_meta.head_dim; + auto d = weights->meta->text_meta.hidden_size; + // [ndev, d, nh // ndev * dh] + for (int idev = 0; idev < ndev; idev++) { + auto weight = weights->device_weights[idev]; + size_t offset = idev * d * (nh / ndev * dh) * dsize(weights->meta->dtype); + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].attn_o_proj->load((char *)cpu_ptr + offset, weight->load_stream); + if(weights->transpose_weight) { + weight->w_lang->layers[layer].attn_o_proj = + weight->w_lang->layers[layer].attn_o_proj->permute({1,0}); //[nh/ndev*dh, d] + } + } +} + +// --- MLP +void load_mlp_norm(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading mlp norm " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].mlp_norm->load(cpu_ptr, weight->load_stream); + } +} + +void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading mlp gate " << layer << " from " << cpu_ptr << std::endl; + int ndev = 
int(weights->device_weights.size()); + auto di = weights->meta->text_meta.intermediate_size; + auto d = weights->meta->text_meta.hidden_size; + // [ndev, 2*di // ndev, d] + for (int idev = 0; idev < ndev; idev++) { + auto weight = weights->device_weights[idev]; + size_t offset = idev * (2 * di / ndev) * d * dsize(weights->meta->dtype); + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].mlp_gate_up->load((char *)cpu_ptr + offset, weight->load_stream); + if(weights->transpose_weight) { + weight->w_lang->layers[layer].mlp_gate_up = + weight->w_lang->layers[layer].mlp_gate_up->permute({1,0}); //[d, 2*di/ndev] + } + } +} + +void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading mlp down " << layer << " from " << cpu_ptr << std::endl; + int ndev = int(weights->device_weights.size()); + auto di = weights->meta->text_meta.intermediate_size; + auto d = weights->meta->text_meta.hidden_size; + //[ndev, d, di // ndev] + for (int idev = 0; idev < ndev; idev++) { + auto weight = weights->device_weights[idev]; + size_t offset = idev * d * (di / ndev) * dsize(weights->meta->dtype); + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_lang->layers[layer].mlp_down->load((char *)cpu_ptr + offset, weight->load_stream); + if(weights->transpose_weight) { + weight->w_lang->layers[layer].mlp_down = + weight->w_lang->layers[layer].mlp_down->permute({1,0}); //[di/ndev, d] + } + } +} + +// --- Vision weights +void load_patch_embed_weight(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading patch embed weight from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->patch_embed_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_patch_embed_bias(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << 
"Loading patch embed bias from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->patch_embed_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_pos_embed_weight(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading pos embed weight from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->pos_embed_weight->load(cpu_ptr, weight->load_stream); + } +} + +// Vision block attention +void load_attn_proj_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision attn proj weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].attn_proj_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_proj_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision attn proj bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].attn_proj_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_qkv_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision attn qkv weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + 
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].attn_qkv_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_attn_qkv_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision attn qkv bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].attn_qkv_bias->load(cpu_ptr, weight->load_stream); + } +} + +// Vision block mlp +void load_mlp_linear_fc1_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision mlp fc1 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].mlp_linear_fc1_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_mlp_linear_fc1_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision mlp fc1 bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].mlp_linear_fc1_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_mlp_linear_fc2_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision mlp fc2 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].mlp_linear_fc2_weight->load(cpu_ptr, 
weight->load_stream); + } +} + +void load_mlp_linear_fc2_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision mlp fc2 bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].mlp_linear_fc2_bias->load(cpu_ptr, weight->load_stream); + } +} + +// Vision block norm +void load_norm1_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision norm1 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].norm1_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_norm1_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision norm1 bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].norm1_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_norm2_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision norm2 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].norm2_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_norm2_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading vision norm2 bias " << layer << " from " << cpu_ptr << 
std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->blocks[layer].norm2_bias->load(cpu_ptr, weight->load_stream); + } +} + +// Deepstack merger +void load_deepstack_merger_linear_fc1_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger fc1 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].linear_fc1_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_deepstack_merger_linear_fc1_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger fc1 bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].linear_fc1_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_deepstack_merger_linear_fc2_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger fc2 weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].linear_fc2_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_deepstack_merger_linear_fc2_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger fc2 bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < 
int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].linear_fc2_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_deepstack_merger_norm_weight(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger norm weight " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].norm_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_deepstack_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { + std::cout << "Loading deepstack merger norm bias " << layer << " from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->deepstack_mergers[layer].norm_bias->load(cpu_ptr, weight->load_stream); + } +} + +// Merger +void load_merger_linear_fc1_weight(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger fc1 weight from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->merger->linear_fc1_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_merger_linear_fc1_bias(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger fc1 bias from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + 
weight->w_vis->merger->linear_fc1_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_merger_linear_fc2_weight(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger fc2 weight from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->merger->linear_fc2_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_merger_linear_fc2_bias(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger fc2 bias from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->merger->linear_fc2_bias->load(cpu_ptr, weight->load_stream); + } +} + +void load_merger_norm_weight(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger norm weight from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->merger->norm_weight->load(cpu_ptr, weight->load_stream); + } +} + +void load_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr) { + std::cout << "Loading merger norm bias from " << cpu_ptr << std::endl; + for (int dev = 0; dev < int(weights->device_weights.size()); dev++) { + auto weight = weights->device_weights[dev]; + RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); + weight->w_vis->merger->norm_bias->load(cpu_ptr, weight->load_stream); + } +} + + +static Qwen3vlWeightLoader weight_loader = { + // Language model loaders + .lang_loader = { + .load_input_embd = load_input_embd, + .load_output_norm = load_output_norm, + .load_output_embd = load_output_embd, + .load_attn_norm = load_attn_norm, + 
.load_attn_q_norm = load_attn_q_norm, + .load_attn_k_norm = load_attn_k_norm, + .load_attn_qkv_proj = load_attn_qkv_proj, + .load_attn_o_proj = load_attn_o_proj, + .load_mlp_norm = load_mlp_norm, + .load_mlp_gate_up = load_mlp_gate_up, + .load_mlp_down = load_mlp_down, + }, + // Vision model loaders + .vis_loader = { + .load_patch_embed_weight = load_patch_embed_weight, + .load_patch_embed_bias = load_patch_embed_bias, + .load_pos_embed_weight = load_pos_embed_weight, + .load_attn_proj_weight = load_attn_proj_weight, + .load_attn_proj_bias = load_attn_proj_bias, + .load_attn_qkv_weight = load_attn_qkv_weight, + .load_attn_qkv_bias = load_attn_qkv_bias, + .load_mlp_linear_fc1_weight = load_mlp_linear_fc1_weight, + .load_mlp_linear_fc1_bias = load_mlp_linear_fc1_bias, + .load_mlp_linear_fc2_weight = load_mlp_linear_fc2_weight, + .load_mlp_linear_fc2_bias = load_mlp_linear_fc2_bias, + .load_norm1_weight = load_norm1_weight, + .load_norm1_bias = load_norm1_bias, + .load_norm2_weight = load_norm2_weight, + .load_norm2_bias = load_norm2_bias, + .load_deepstack_merger_linear_fc1_weight = load_deepstack_merger_linear_fc1_weight, + .load_deepstack_merger_linear_fc1_bias = load_deepstack_merger_linear_fc1_bias, + .load_deepstack_merger_linear_fc2_weight = load_deepstack_merger_linear_fc2_weight, + .load_deepstack_merger_linear_fc2_bias = load_deepstack_merger_linear_fc2_bias, + .load_deepstack_merger_norm_weight = load_deepstack_merger_norm_weight, + .load_deepstack_merger_norm_bias = load_deepstack_merger_norm_bias, + .load_merger_linear_fc1_weight = load_merger_linear_fc1_weight, + .load_merger_linear_fc1_bias = load_merger_linear_fc1_bias, + .load_merger_linear_fc2_weight = load_merger_linear_fc2_weight, + .load_merger_linear_fc2_bias = load_merger_linear_fc2_bias, + .load_merger_norm_weight = load_merger_norm_weight, + .load_merger_norm_bias = load_merger_norm_bias, + } +}; + +__C Qwen3vlWeights * +createQwen3vlWeights(const Qwen3vlMeta *meta, + infiniDevice_t device, + 
int ndev, + const int *dev_ids, + bool transpose_weight) { + auto weights = new Qwen3vlWeights(meta, device, ndev, dev_ids, transpose_weight); + return weights; +}; + +__C Qwen3vlWeightLoader * +createQwen3vlWeightLoader() { + return &weight_loader; +} \ No newline at end of file diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index edf0faeb..37d8712a 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -267,7 +267,7 @@ void print_data_bf16(uint16_t const *data, const std::vector &shape, std::cout << std::endl; } else if (dim < shape.size() - 1) { for (size_t i = 0; i < shape[dim]; i++) { - print_data(data + i * strides[dim], shape, strides, dim + 1); + print_data_bf16(data + i * strides[dim], shape, strides, dim + 1); } } } From 15ee467fb0f42d743227084ff14d32686447fc5c Mon Sep 17 00:00:00 2001 From: hejianlin <892082223@qq.com> Date: Fri, 26 Dec 2025 21:50:08 +0800 Subject: [PATCH 2/3] fix Multiple requests infer --- scripts/launch_server.py | 28 ++++++----------- scripts/qwen3vl.py | 2 +- scripts/qwen3vl_test.py | 2 +- scripts/test_perf.py | 4 +-- src/models/qwen3vl/qwen3vl.cpp | 34 ++++++++++----------- src/utils.hpp | 55 +++++++++++++++++++++++++++------- xmake.lua | 7 +++++ 7 files changed, 82 insertions(+), 50 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 2d231b49..01b2f60d 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -1,5 +1,4 @@ -from jiuge import JiugeForCauslLM -from jiuge_awq import JiugeAWQForCausalLM +from qwen3vl import Qwen3vlForCauslLM from libinfinicore_infer import DeviceType from infer_task import InferTask from kvcache_pool import KVCachePool @@ -60,14 +59,9 @@ def parse_args(): "--max-tokens", type=int, required=False, - default=None, + default=200, help="Max token sequence length that model will handle (follows model config if not provided)", ) - parser.add_argument( - "--awq", - action="store_true", - help="Whether to use AWQ quantized model (default: 
False)", - ) return parser.parse_args() @@ -76,7 +70,6 @@ def parse_args(): model_path = args.model_path ndev = args.ndev max_tokens = args.max_tokens -USE_AWQ = args.awq MAX_BATCH = args.max_batch print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." @@ -93,7 +86,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): "id": id_, "object": "chat.completion.chunk", "created": int(time.time()), - "model": "jiuge", + "model": "qwen3vl", "system_fingerprint": None, "choices": [ { @@ -122,14 +115,9 @@ def output(self, out_token): @contextlib.asynccontextmanager async def lifespan(app: FastAPI): # Startup - if USE_AWQ: - app.state.model = JiugeAWQForCausalLM( - model_path, device_type, ndev, max_tokens=max_tokens - ) - else: - app.state.model = JiugeForCauslLM( - model_path, device_type, ndev, max_tokens=max_tokens - ) + app.state.model = Qwen3vlForCauslLM( + model_path, device_type, ndev, max_tokens=max_tokens + ) app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) app.state.request_queue = janus.Queue() worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) @@ -169,6 +157,8 @@ def worker_loop(app): batch.append(req) except queue.Empty: break + + print(f"infering {len(batch)} tasks") output_tokens = app.state.model.batch_infer_one_round(batch) for task, token in zip(batch, output_tokens): task.output(token) @@ -298,7 +288,7 @@ async def chat_completions(request: Request): if __name__ == "__main__": - uvicorn.run(App, host="0.0.0.0", port=8000) + uvicorn.run(App, host="0.0.0.0", port=8008) """ curl -N -H "Content-Type: application/json" \ diff --git a/scripts/qwen3vl.py b/scripts/qwen3vl.py index 12cc2d72..9f7468af 100644 --- a/scripts/qwen3vl.py +++ b/scripts/qwen3vl.py @@ -612,7 +612,7 @@ def test(): ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 model = Qwen3vlForCauslLM(model_path, device_type, ndev, max_tokens=1024) - model.generate("山东最高的山是?", 50) + 
model.generate("山东最高的山是?", 200) model.destroy_model_instance() diff --git a/scripts/qwen3vl_test.py b/scripts/qwen3vl_test.py index 893cad6e..8c58e637 100644 --- a/scripts/qwen3vl_test.py +++ b/scripts/qwen3vl_test.py @@ -17,7 +17,7 @@ # 设置生成配置以确保确定性生成 model.generation_config = GenerationConfig.from_pretrained("/home/user/workshop/Qwen3-VL-2B-Instruct/", trust_remote_code=True) model.generation_config.do_sample = False # 关闭采样以确保确定性 -model.generation_config.max_new_tokens = 50 +model.generation_config.max_new_tokens = 200 # 输入消息 - 结合文本和图像(这里仅保留文本示例) messages = [ diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..b1951186 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -30,8 +30,8 @@ NUM_REQUESTS = 10 CONCURRENCY = 5 -API_URL = "http://127.0.0.1:8000" -MODEL = "FM9G-7B" +API_URL = "http://127.0.0.1:8008" +MODEL = "qwen3vl" async def benchmark_user(client, semaphore, queue, results, user_id, verbose): diff --git a/src/models/qwen3vl/qwen3vl.cpp b/src/models/qwen3vl/qwen3vl.cpp index d784b15e..d4bdaa12 100644 --- a/src/models/qwen3vl/qwen3vl.cpp +++ b/src/models/qwen3vl/qwen3vl.cpp @@ -66,7 +66,6 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, assert(meta.text_meta.num_key_value_heads % ndev == 0); auto dtype = meta.dtype; - printf("meta dtype: %d \n",dtype); auto nlayer = meta.text_meta.num_hidden_layers; size_t nh = meta.text_meta.num_attention_heads / size_t(ndev); size_t nkvh = meta.text_meta.num_key_value_heads / size_t(ndev); @@ -92,6 +91,10 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_cpu = std::vector(nreq); + auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); + auto q_buf = qkv_rope->slice(1, 0, nh); + auto k_buf = qkv_rope->slice(1, nh, nkvh); + //Prepare inputs auto batch_pos_ids = std::vector(ntok); size_t req_start = 0; @@ -130,12 +133,11 @@ void 
inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, max_seq_len = std::max(max_seq_len, size_t(seq_len)); } - auto attn_score_buf = Tensor::buffer(dtype, {nh * max_qk_size}, rsrc.memory_pool); - auto attn_val_buf = Tensor::buffer(dtype, {nh, max_seq_len, dh}, rsrc.memory_pool); - auto rearrange_q_buf = Tensor::buffer(dtype, {nkvh, ngroup, max_seq_len, dh}, rsrc.memory_pool); - auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); - auto q_buf = qkv_rope->slice(1,0,nh); - auto k_buf = qkv_rope->slice(1,nh,nkvh); + auto qk_buf = Tensor::buffer(dtype, {nh * max_qk_size}, rsrc.memory_pool); + auto rearrange_q_buf = Tensor::buffer(dtype, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh}); + auto attn_val_buf = Tensor::buffer(dtype, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); @@ -174,18 +176,17 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, auto full_v_buff = caches[req]->v[idev][i]->slice(0,0,total_len)->permute({1,0,2});// [nkvh, total_len, dh] //self-attn - auto attn_score_req = attn_score_buf->slice(0,0,nh*seq_len*total_len)->view({nkvh, ngroup*seq_len, total_len}); - auto rearrange_q = rearrange_q_buf->slice(2,0,seq_len); - rearrange(rearrange_q,q); + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto attn_score_req = qk_buf->slice(0,0,nh*seq_len*total_len)->view({nkvh, ngroup*seq_len, total_len}); // [nkvh, ngroup * seq_len, dh] @ [nkvh, dh, total_len] = [nkvh, ngroup * seq_len, total_len] - linear(attn_score_req,rearrange_q->view({nkvh, ngroup * seq_len, dh}),full_k_buff,1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + linear(attn_score_req,rearrange_q_buf->slice(1, 0, ngroup * seq_len),full_k_buff,1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); // softmax 
auto qk_softmax = attn_score_req->view({nh, seq_len, total_len}); causalSoftmax(qk_softmax,qk_softmax); - auto attn_val_req = attn_val_buf->slice(1,0,seq_len)->view({nkvh, ngroup * seq_len, dh}); // [nkvh, ngroup * seq_len, total_len] @ [nkvh, total_len, dh] = [nkvh, ngroup * seq_len, dh] - linear(attn_val_req, attn_score_req, full_v_buff, 1.0, 0.0, nullptr, nullptr); - rearrange(o,attn_val_req->view({nkvh, ngroup, seq_len, dh})); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), attn_score_req, full_v_buff, 1.0, 0.0, nullptr, nullptr); + //printf("rearrage o; layer[%d]\n",i); + rearrange(o,attn_val_gemm->slice(2, 0, seq_len)); token_offset += seq_len; } linear(logits_in, o_buf, weights->w_lang->layers[i].attn_o_proj, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); @@ -217,7 +218,7 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, // sample and output if (idev == 0) { if (last_logits != nullptr) { - rmsnorm(logits_out, logits_in, weights->w_lang->out_norm, meta.text_meta.rms_norm_eps); + rmsnorm(logits_out, logits_in, weights->w_lang->out_norm, epsilon); auto last_logits_buf = Tensor::buffer(dtype, {ntok, dvoc}, rsrc.memory_pool); linear(last_logits_buf, logits_out, weights->w_lang->out_embd, 1.0, 0.0, nullptr, nullptr); RUN_INFINI(infinirtStreamSynchronize(stream)); @@ -231,9 +232,8 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, rmsnorm(logits_out->slice(0, req, 1), logits_in->slice(0, token_offset - 1, 1), weights->w_lang->out_norm, - meta.text_meta.rms_norm_eps); + epsilon); } - //logits_out->slice(0,0,nreq)->debug(); different from transformers linear(prob_buf, logits_out->slice(0, 0, nreq), weights->w_lang->out_embd, 1.0, 0.0, nullptr, nullptr); std::random_device _rd; std::mt19937 gen(_rd()); diff --git a/src/utils.hpp b/src/utils.hpp index b0da9fff..17b35628 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -7,9 +7,37 @@ #include #include +#ifdef __linux__ +#include +#include + +inline 
void printStackTrace() { + void *buffer[100]; + int nptrs = backtrace(buffer, 100); + char **strings = backtrace_symbols(buffer, nptrs); + + if (strings == nullptr) { + perror("backtrace_symbols"); + return; + } + + fprintf(stderr, "Stack trace:\n"); + for (int i = 0; i < nptrs; i++) { + fprintf(stderr, "%s\n", strings[i]); + } + free(strings); +} +#else +// 在非Linux系统上的备用实现 +inline void printStackTrace() { + fprintf(stderr, "Stack trace not available on this platform\n"); +} +#endif + inline void assertTrue(int expr, const char *msg, const char *file, int line) { if (!expr) { fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line); + printStackTrace(); exit(EXIT_FAILURE); } } @@ -20,6 +48,7 @@ inline void assertTrue(int expr, const char *msg, const char *file, int line) { #define PANIC(EXPR) \ printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \ + printStackTrace(); \ exit(EXIT_FAILURE) #define RUN_INFINI(API) \ @@ -29,6 +58,7 @@ inline void assertTrue(int expr, const char *msg, const char *file, int line) { std::cerr << "Error Code " << api_result_ << " in `" << #API << "`" \ << " from " << __func__ \ << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + printStackTrace(); \ exit(EXIT_FAILURE); \ } \ } while (0) @@ -38,21 +68,26 @@ inline float f16_to_f32(uint16_t h) { int32_t exponent = (h >> 10) & 0x1F; // Extract the exponent uint32_t mantissa = h & 0x3FF; // Extract the mantissa (fraction part) + union { + uint32_t int_value; + float float_value; + } converter; + if (exponent == 31) { // Special case for Inf and NaN if (mantissa != 0) { // NaN: Set float32 NaN - uint32_t f32 = sign | 0x7F800000 | (mantissa << 13); - return *(float *)&f32; + converter.int_value = sign | 0x7F800000 | (mantissa << 13); + return converter.float_value; } else { // Infinity - uint32_t f32 = sign | 0x7F800000; - return *(float *)&f32; + converter.int_value = sign | 0x7F800000; + return converter.float_value; } } else if (exponent == 
0) { // Subnormal float16 or zero if (mantissa == 0) { // Zero (positive or negative) - uint32_t f32 = sign; // Just return signed zero - return *(float *)&f32; + converter.int_value = sign; // Just return signed zero + return converter.float_value; } else { // Subnormal: Convert to normalized float32 exponent = -14; // Set exponent for subnormal numbers @@ -61,13 +96,13 @@ inline float f16_to_f32(uint16_t h) { exponent--; } mantissa &= 0x3FF; // Clear the leading 1 bit - uint32_t f32 = sign | ((exponent + 127) << 23) | (mantissa << 13); - return *(float *)&f32; + converter.int_value = sign | ((exponent + 127) << 23) | (mantissa << 13); + return converter.float_value; } } else { // Normalized float16 - uint32_t f32 = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); - return *(float *)&f32; + converter.int_value = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); + return converter.float_value; } } diff --git a/xmake.lua b/xmake.lua index 598ac534..51ded2fe 100644 --- a/xmake.lua +++ b/xmake.lua @@ -1,8 +1,15 @@ local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. 
"/.infini") +add_rules("mode.debug") + target("infinicore_infer") set_kind("shared") + if is_mode("debug") then + add_ldflags("-rdynamic", "-g") --调用栈中显示函数名 + add_cxxflags("-g", "-O0", "-fno-omit-frame-pointer") --获得最佳调试信息 + end + add_includedirs("include", { public = false }) add_includedirs(INFINI_ROOT.."/include", { public = true }) From 39e2d77c92c2307af626cf94dada80ce08de47f1 Mon Sep 17 00:00:00 2001 From: hejianlin <892082223@qq.com> Date: Mon, 12 Jan 2026 17:12:36 +0800 Subject: [PATCH 3/3] add qwen3vl visual processing and visual pos_emb --- 010P00002405F02D94-1.jpg | Bin 0 -> 23513 bytes include/infinicore_infer/models/qwen3vl.h | 40 ++- qwen3vl_test.sh | 27 ++ scripts/infer_task.py | 5 +- scripts/launch_server.py | 17 +- scripts/libinfinicore_infer/__init__.py | 2 +- scripts/libinfinicore_infer/base.py | 1 + scripts/libinfinicore_infer/qwen3vl.py | 59 ++++ scripts/qwen3vl.py | 72 ++++- scripts/qwen3vl_test.py | 97 +++---- scripts/test.py | 2 + scripts/test_perf.py | 2 +- src/cache_manager/opcache_manager.hpp | 2 + src/models/inference_context.cpp | 34 +++ src/models/inference_context.hpp | 10 + src/models/qwen3vl/qwen3vl.cpp | 335 +++++++++++++++++++++- src/models/qwen3vl/qwen3vl_impl.hpp | 35 ++- src/models/qwen3vl/qwen3vl_weight.cpp | 30 +- t012ed7ed15c1fafc48.jpg | Bin 0 -> 114478 bytes wget-log | 11 + 20 files changed, 648 insertions(+), 133 deletions(-) create mode 100644 010P00002405F02D94-1.jpg create mode 100755 qwen3vl_test.sh create mode 100644 scripts/test.py create mode 100644 t012ed7ed15c1fafc48.jpg create mode 100644 wget-log diff --git a/010P00002405F02D94-1.jpg b/010P00002405F02D94-1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8097d8f426f7d66bebfc2b5218e00708f237d78c GIT binary patch literal 23513 zcmeIacU)7;w>Y{JdhdwzD$+@4q4(aq6h#OD0)!4C3RqE5iijvxrB`Xv6%;9o3QDii zR0L@%AfmvV1oe2n_uPAa@4nCPujhuDy=Jvpvu5^WvNE%Ghj*tSdLunUJqQMaL$Tlw z+FjabVBzH-;vekgAHXXkDGi~t3{8oNA^b(l35Os!1m}c@LvSRLKoKHP=ud3UPZbN* 
ziyv2gmT)LmTH2$*4u4O=i~YP4Y^bO@im4MIQRc5a}%_F zCX^`%ImlppQj~-;Y)^+UJ)uMih=^o zei43{hcEku`uc(!7(@`Mn}4XEd$75iHyV_F$XTI75THab+6D|iUw=POP$pa^7>|(O z%UU7+xV`e88CITNKj?%8kp}x>(U@R6Uz1=U4*Od@Bm^2jGlE1$C>ZMerkl(&$UijT zZzxj#Agm|W591e#12WXQ)_Cpa1RN;SxQ2%K>tp;dL9QVfcYuIwjt5|VFjMX^<1u)O zAy5Fb@q0#{52TD1E^|`Y%{&H~$bY zp#BS<3Usm8f3g#vr-S&4AZn1{>Z=5JZysVM!1y5yBk>9tO{b0r&&+Fhs5C?(c?i(KZDf%MdN01i7CAVzAvY=o~~%PEJliPEA2U z%|Jy-#lS*GO-;wb#>B+J#KguxO}KvC_6q-*!)U0eXlQBZX=&-1X=!Ph@fj`io(RML ziGtmS5Cb`A&P^g1KLlrh5i!7aC&1H7x!VR)2Ue262p2FIe-S}2I57z+894_;I|ya8kfy~jQHYv|P$vVV4!yH7r_dTLkhEp6=a*9UpZ9Vsx+`Yr z$dY?K|JLrSN29symud`4ox*0JhX?}%J*b50u)%1%bi4F`bejaP0JRlOyd@FKxJ=w- z>(5e0E1&P+at#<>q(?8w3NAue9d-Fk zqT5N%e9W}*Sr?2Q`v?y0W9k-qA}XP97xH{pjaJXlqqdt?J9x>~$wd(5?ncwg+s@l& z0n!I2gY({6_bl3hLv7C4C86959kIK;&qs3yi9$69HEFf7xd z(Ha6Bf?abK{Zys#py_>~QfGCWwEt*fOxZ3}++Ow+_E_K8;FFltER@04a%{3gc>8If zW@T5xAkXp*`0Dt>*`$u)F|%0*BCjm=jvbGl>=%X1Yo$tuW$#t+yh!DVZShmg{^q69 z`^A0-D`6maBt_hsC>DxOi@Hp~HJTDK?v!3LJ?kJ@wv0|_UtbC;@EoAM5c;}$B=DFZ zM~3E3brzTZ4tXsYi56^uFf9gdOR96Zvc%_PX9sjm@?$6$N^xk8+c0L(CIdOO0TUwT zvo;US;H<@N?78f@@I-Icw4>}vyc9y%g5AEd8hq;w* z%%{xY0rdAu;fZ(@m8m$~NJH>Ecue}e0D<|GeR zLtiAQ9Tk7bI;r7tZs#)?dZHOS_0-W37CfJ4rM{*3IY(7+Kin~uyn8Iv}A@P(j%z)UYf>J|>QiJ#UDJvK+^C?X+C>O3F z0-euKh8rCfJma!Ez)Ptid|sNhI^7t`$CJ)zow9~W1B>iQ9y&c>iwMu|f#!B>W8mYr zw+;&}FN>%j=lPmc^5Qbv>}&L=Yw|SI3G2>!PnpISqP&j^zRDGWN)g3Or!t7Ac&`Ju zOaYB{zz|@Ke4h~jfFZ=Ccz-0h><-=o9@0w0I^GQeXzjsHOj+?z%#Gu5vcfxC#qmwO zX}xoA2H9aD_m-y~7x*uCxjuu!4O#c2&+=VNj6eSpWPXd|fUzZ;)u_5Up>o{tr=gmTCbo^3hg+HhBb7bJzB^qLPGjI@+K{>+rAVZ8pZ5Gf z(+4uc629U5UVoo<^P{9h6Ki_AP&2&ufU zvPZaN)H?muCQpnG)@^-gT*wT%T+A}_ppKFY{V7bzAeeo{tgLLJslWOgWya@^nA7Uw zG__GH+A#tSG)50apz74YiVvL%mGnu>g<=c_G!JP{GxNB%74_kX&xZE3i{hT$8Z9dnXpdQPX zTP{1K)m)jWPHbX@)|NDK?wWjgo#XEsn#+r}%&oVDo+>Vf^kD=91e=UMgx?$f!qrV% z)7|uS@Y}@K1%BUhFJpzjTCCXh zn6D*_RqmtrY)*+S4Ygr3*1ZrO+_7C#7325)#p3HhU!&`u&%Y{9X!M=eoiOXZW$l0~ 
zz0+UupuhG*-5hek0UO<1+@f8_F|wt;s+LWr*CK#i5Nn^0`$oOg{NF{fSV(QNhCHs_P~JEW&G%5E;gUfiqMAX=hM9Gz$ybJ;ZADd>yNv7cw*ETt{E zvUsAn(7i2X=hDE9>D`1k< z0^1X(b|GsY$HRV0FFfWV)Umke)WhXlJ~G+soZZQDBi83{)~G1(cV%%7UZ!nbBAwgj z!3|23?t3>oHLK3SHj+$K@}ji5FBs4p?pD5S&bz0(3$5~Ccfy(?&8tgIn$t6MkNZAG z`z*-Uzg&2PcEpa0w?*P)?X|tfpZTtH%ujybj?TJ1akff6o9p$Jz)#x(z3(4fZ~ioX z+I9WNp~2c`{@AcdyLUAM&9(MBSG`9D=AS=_VpZ)stox#-=2qDJC`ZftmM>qb=kKD4 zazY~(^9C@LWh4V*p|l}04P(Bnlq2seSDJPxhMM0++}njz=j|Gu7Uu7am6xSHtSWW# z_VPtF`5gMPO=na*=D(I;nRR(NXJE&X2_Z`{5J~|<`z;T$`cW-lLdJP&*(sD7_2PG| z>PeI8ujH>sWnQkPV@+TExJ=X9&`tA1Zd>%hm9@zek`wbXU!!vxuxqqMNu#4X*(Cn* z>irT}>d~d<8zt3Y&z#T?q?|@vHaVDI^;9Ho#ZY|OFuXKW8F4g&BsF@F`$ill)cK3t zM#oe3{o6y4`^rR1aiw(bODYBlkBr-B+fTJxj$2DJXh*F*Q=Z-~wAN_fS4;T>AZ;;s>+%%j<4kC4Lr>m`@T_IX;|D4Lk4TmO^b-qIzCM zrgf?7^_b5Ww`Pd?jHP_f7#gMeuq{?plwH9&@nz{)+1FdSyU^)fi1^;+32p5Md@Yy+ zOrug~)ys~tLn*VvyO3$=25!gy>phCaVThwL%4cj{bvdtQW}Ou!yU%I}S3byb=!2@F z@`#Vv+5`Q%$({5assr6;Z(KI0=HMEU9sJ@~&HZ9J-FW!ritdJQWmCp3G#43rSmDDy z%cmERJ?3#g6%+#YRqyEq8i;zPO)l zLZiYdbinG+;1=7`3EH7UD|Q_YP71!kOX}0n#@~mEzK>@5w>ELrkEBF2uq+>EUm=;! ziXWo6($KUKeaXDM%HF?x@)GJxX8OZ|Phj*5->%JYempNa&DDNAw)8BsL znZ30Or5n3js;Qq%ON0qS)ld z#8<82(asB_hZ}lEKi7QAnhcdJ11rlq%gpuQr5dH@#ve{}ug-_4C^i@5297n{i=abe zCs)4dNzF$(E;)ohU>&k?e#eoTP?9#d1RqlF^N_gl%=ydVy0s%WUhW+Gde3&@b91@w zx{rFn_?(;bE@bq9z2nO1W!eD#mB4L@sHO*EPRl#~eI<)e)xwaZKebCIr> z7wQV$I3-|cn{U7HEh>F=mt@i~?%3xmJ16?Kg}?L1DA+aLcWyD>Uw`u}j`jrx_UO^l z=)2(|M0}`e^|7jysIo1WK7mM=7f+rotXm*Ii?)+=)IL~Woc!9H=c3(P89%ZS-Me-1 zw5EM%`SAS9?wr(xIx*5>Gez`-X-{Rh`t*k`+;bibddbIkTPVzWCm~{}{Da*AF?6Ba z=IhKe8D~9esW^KaQgG$R$`X8>wm4&>Z|y>))!qw-t%Zvm!kD~|4VMNyK`(}yo$9c? 
z>S21dhJgmFhP`|2&ZJOv)aUao+st00C`&aIkq;H3<|e71~F0?$1gYxcJK!e;DjLvvYdb~3Wg-W_*hfg1uP+8 zif>c}NaG;kBf9u9;ZBGf@%|Ol*S41Bwa5Co`yUI2kkSZ*k~B&QDbFjds3a|~B#n>1 zLPWP=f7x~${!>3Z_BPQk_-E{w6h`!e7qtJccECjcY6lGYdpqE;pE_Vl@L#$?By8{c zo9%xhW1t`ZM8^I^#{NXc{zS(9M8^I^#{NXc{zS(9M8^I^#{NXc{zS(9M8^I^#{NXc z{zS(9_mQ#P8R|oZ+S)D_R_1z!Cc1>k7`1^3*3TafcJ1N(LV~RHwRjIXIP#KpgMbqS z*g6Q}L$2uH08?u{8+^RRP*1uyU4Y5!VcKnDbcfKZYJ z;FNL41cUv7uyX+RJr)vx$FBmI+06$J!|~g!n1g_U0A|C(o_nx3foBiaz{Bo1KX<@` zw>iKa=Z=Tl0DLhl6u)1O$OOPKVOY#D0QUn}&^Hu^1u%YdCo>M?3ic2Z<2P~%hG5WM z0G0+Yb&!pfHh@(?WRKeOH`whrI0S5*1H6#7f536VRx;j$Xc1n7q9Tgd0CUV26A~g} z;fnTg4RYtz_QwUd`W=U$y*A^eAV$1zd4ZGV5sLEi5=coxc#c5&7vc||f4kxPc8}+? zC85t?tk3_#{U!Yu&OaM$Fa;ZFi4y<9xurr--FXON8~qC>cpC&+&p}Xa&u{Hv#g7-S zkdOc+DXC+}j!9xMXi2<7e_j8|@Qd@m2Y#zh5?|jh-|_0~?L6egI~5If5{3rx1_!vJ zF}xE0If(z)g1?#dn;qg-7!OPk#t&3w2Sypz&l7aFpF1`L>+i>l_51I7_Rzq4CsXXZy7uR z6yd>GPhLD)+scL)9U2sdhktC~C4neGl${yc4{<{R&_S@Z7zs8OD?=KPE@T9mLpG2D zbOb_!csmXXf{sBaq0>+dbQwy7u0v_iEhrZ%g6=|9P#tt1dIWVq&!Csk05k?oL2scC z&?>Y6eFNJI$zU`vCfI%$4@?Lq1_I3}m>NtMW&*QJYjyYP}oV>S=eP*5-bgN z2UY~Dgf+k(!Jfii!A4-yuzA=g*cO}!P7P;<+YvXTmuDv%nGI*@vko*=zKnn_wt z`jGS$=`86w85!ArGBGk$GIKIlvJkQuvQ)BCvKF#lvbSWL3ZlE z=_%+B&}-AX(nry!(>Kr$(XTTwF(4Ss8E_0&7>XI5GQ4LbV?4m9%jnK{jxmSvG2<)~ z5t9IuHWQlZEclhXo#`zzDYG!MA+tC0W#%&GZst`MMiyBXJC<;kbe0D!Q>;X+2UrbR zeOVJ&t5^qFxAt-E)7*#IcWK|uyL?yuwmF@*~-}l*tYib?$_P#v;XS;y8RRE z#O$K%mh8vav)J3&7de*o8$FU)Vt ze}=!De?ov#043ldkSx$DuqwzSXexM8utacFh(ZV@m-+@1f?9Mu1a-E?MTZ>`%33Zk0BTlMu;^YGiuNblgnH%+IXO95OdFzh*vRA#4$7(PT+#X8`xhawL>b6|JCI@CLoIodhq zJFYwFI$d{q>n!IS<2-m+_;C2)r$^Y2_#S!SLg(V@Qhk)Vv0)0C%uPQN%KeJ1J5+FARvjpw+|ojNyrUjKYa3~fwM%+Lkp z3%4#3U&LPQz9e%g_0mqPdu-Qb>B}jXzh6OL>54C{^{IlXiK*YyywV2JwbIKo zI5RF~e93gr?9I~1x_gW3*5zAUx3RZ}@95vD%|4Kwl0%vkmhñIrel~7$1r>y?2#N9jPtmT(0po-9nh03FqgH^^= zt<|XNyZ40dWz;a$#MZ)UkJql&dDTtTJJi2wFluORRBo(kl58rt&v!rl0n3AgX3FMs z&ATlTE$a^h9xgued^FqY(mM9|(Bu9#%eL-zROx9$7pkl7 znf$Zb=W@?$UdX+u>6Y)V?NR8df2sJg@s-M}=3e#QM}0bd9sP#=&j!o}UJcp~4i7mG 
zO%1ybzaQ}(`8XOj`fcpgILUbI1pP$HYtGl%lOmJlQ}R;}rgf&D&sfil&7x;Nya|5u z{q4DTH1AU8c;<@UBj4Yj*PDO2;I#1eL%@e`i!n<4cxQYU2aY>U zGfe(i=lKOoy5bKqNTU65QsJ%vQV2<9IU*2Fd@8d7;vygm4BtFg`XGeuF7vO zXDV$PppEgu8bkzPtRu{9&=EdpMR$I6HC~l)rEpvT4in5-9;VB!Z9o zg}8=G_yr66)Sv^t#ROpk2&+B3Mptn7DMXbY*t*vUT);22|LCosjPB@Pd;#Em5212* zv=qh{gTwfR1Ov|^@D3m-4fuJ8@0ZBmRR1Cqy5px1aIO{pz)SEy3;pap!Z(@VUsVAq zQ`7%h0*Cv>9UP(?1{(iAjgDaa2_GpdOfXoXqcOT+pm72}O$@<${M8A6YZ9LFx7J|5 z?u&n}2v7ON8la>dgmDc4=ZgIOeSZ#{f7~yDx8VmjMh6`AQsu{QppcM8Ng(BI@Ea)P zzz7lp`zfTQ_gGE+-LW3W|B)5H*Wz!iVD!4XhPeJKvpd)Y>mP)31wO>$Ts<*T-T@d- ze%@alszf*n1!@N7$^4E7-}Ehl{5`O~fbd_wR}$>q>w6_r*B~DZcoOwoeSvP5 ziXWjrW!>?z3jb%adxz{?egAhl@MA&=r~s9_;vYX%egr=+0eJ{NynI7{AY~L36s4u5 z5YpiK!Qk+341dvW{ji`nfU^mj%|K0I7#*x9CKzCma`NIxIR%hVASsH2X-6EXAPW+} zp=8D7WTZht02ts=GU7-@d2tzOgt!a>1(Kq;91c}Hxb>tKjv{6WTgqEVNtgf^o(4wFztBK_2#SeKU3v4*X_mB=2bVe}28}f3X zxqp!yekY;ievk+Q${PPgnczMG``;TWBaf1o6_=G(l$R#tzAHa(E90fQ*j;~7wxDcP? zYmou2#pgg4+LU7)D?efB^%>fwGRap1dYd zC#|KeDXpNXE3XK8TuV+{5hW`xt1Y94{}=+Z20k9~<5TFzC!3~k$PdUDi~rbA@^$s| z1dq4Gk0YA=YDlDt6#f&-7Yi5((4LIcf6+(y5KuC}27`(8_z$htn83gE6Xty-SM=T! 
zhsIC278W|x1RA@*dCd11hF{U0;&@6+oZTd==J$T8O-jHV};4F2b7Uh2QI zI@mSrUu^!@8S(#?%|9H8_5xjqpS-2`@vR_q*EA-GYd0b35*EH z5R*`lkWetwP|+~|zrOa~LXiT+|MnJ&HVC|31CDecir z5aqz{$OpeX;f3}t!UlZ^Mh1uo@a71-;-6Xp170x$2_qjulgToWAL)Aa(JCp6)#;CJ zcPd+_|KmjuM;z;gM=x5)*b~n*zsc^o*7yu+rAxanyk$F0yjjv1;!TXTBPEF-J^0io z&_1v1o3;Ga3Fiu!VC;Qax>M0c7=}-XSW*6KGY=k?tyE+?PO`Nh;ZPjCrDO8Sn))lt z!%rieDQ0@Q^vE15!!zcwNIv(las|QQv+jrD%X=@f3F54Z@MvEDBYv8qCKpYDYZwsm zC*1O#PVb~4kD|Zdv80drYSy~aO>Tt!<>N;B2YcG-F3MTTJ=H8pdW}u& z2y164>>SO~`N+X)EVxL*BF-2k(y#k5RyokTCIIDYSU=Ly|M1dBEAh4`WXz+k3EpJl zbaC>74$R}~Z23;-$Gork9ZZG+qYfmiA^Pv2T z$y%GtA(}m|)-;bcn8yNSX3mDUn-eo0J-U%o2hG)(XOG+NI4XxtFKOPFqoyaXn`_@M zzx+Dw2}fOmw=3&vU=8hRKSKilSe-Xg;`lBk;T|{kF~O|g-KqH?=i^~7^V&&NzC)}* zw@82cy0ve6hkVICWR~n#Mjf*;Yw}C=qqz(m6&(kc^z!@!i@hvwHif-TGclR<4!xPm zbzY_4dDh$Ms*wH}M12&(q5H-(%cP~gx8eJ-85DA?>JsIlFGi$qPG(SEu4(mOaahO| zskm)(?F}7^dMV~3t;s&W+T$k}Ol|OtsvUOtnrs(|(``*ZLMu*zn*ov=NxJ#>9 z_o7DQTa(IpR+mPlA?et8wR4HI(swcrE+y)p^gs@^$IHx*J@NR=Yv@%%5xlzLjB&YY zK*mz(KkhjA#Z>ixY}k~Dl(f~t;aZ_J5kdcyeglD7(#mT?HP3vmQ}zb!WEMp->Bdpa z&(GL=8Y=3xYZmub)=Iw`e1lZS*1f1fV(sGhSPk{7ZWltf(L18 z6}f4*zl_#T+HkoP`1wQ_%akKB-Wg+_vG<+3*2tJA57(22JX!i$Oif$_<{cO@j+0e- z*SQ$Z23a@-t}@%?t!=y|b4@}9o~{sF7B62evAx0=9vM6m^)~Zu*{w4RuEJxS!S8)Z zr`^MHvMD;fZ|G>KSkttP)ee*fA&rQ&!?wR!OK5NpNRCUzV(Jf99pvmd$UuIMELQ30 zTzyfp7I%M$t^r3ZnF_T{w+}-_a9mM7dP}eW<(1&kR1dYy61}eQHd}_GTDOfc<}u6Q z?fF50@Q(gEzA@dK7p}Nl$LVtSk0I{UoR=Rz39AvR>{21B)u4NhtsXIg*Y>b2iq<-W zC^_;pu+b8q9#-t8Y;k7Zg+zz!1{;wh4qPeMwtY^1R1=;#$yL!MAa+{0;MQ6hademN z$v%sO)0@rR_Z!WUXCByG+l3hQVg^McLcQ)56@2+5D;C%+l8;7Lbf}(}9Y0Ci-xt_Z z6Up)X-Ve$>}SeKe7On$Lk|lP%zj<$scJP0z&lq&3#_3A05v)=!(+EWji>FxItj zqBdW9zvXCrpYp&MhL(y0JsW{5+^)@*o%9Y&Q9f$*NsiaIddzQa*4ORgv;9kEay;BJ zpEZWx-!`d$k=1#>=Q4VD?DS_%o=5!3x0Z>%4JLRfD7{QnP85ssVJ5%N6M1(+e^~!p z^>zu-UH!1i0>ee`jY!av%{4!0{ki&{%XtZK8~eBg7=?qgDcg?eGxt9DP~NEj}jR;drxS#M;%qwBrR zw{KaFtATB-nC{c-+I{TYZ`~fz?F^2G?Bv3BA(n*uM z_u{6y1v=6m0si~ip1#(mav4@*`?-xr{9S8MeJ3a@^DDf^YF?9O)v2n@Nfb&uKn;9v 
zgYG~Xe%>l?lAemv2YDkFo$o=(>uB=E%iD^94h>xBw12 zhtSNe;1ukVxsGfeZCTkvQDwS$t8uLp%QcyuC(nsUB)sCUT6@brS~|Bcv(ot6+`^IW zjCqGbQ`cH_oI>13eJ!NV;8;#xOK*9)7^8$1LSOrE^#kXzkXDY|25N&!3HOD~;6%;z2WJhuzc7R=ZW)ZTiXBN z_34JseXti(vh9}&E@XzwuM;yR6D7>Bl!g>A`ErJ6Z-wnbZZ+B8`R+8?%!^r@9dNUP zTiFkvN8cEP*3r#0#_uOgBzryPDH<6^TWVu$7B8@eUg*(%Jdg3>^H%BP+#YC9gls5B zueaNchHnO5VOm=IW@LnNs(kfjyZic<#0j%i$G$?{70c1AdXpPmeyG0jmg@PaXS~ab z7v$wXI>lMRC|!M!dp#;}{D395YSkO*Q)0+D*c(wNKY~JtD!K zpC+NO;61xQeC(sGzg1a{z7^w_=ACm-$Hs=NuB5EA$P~K=_GykSi`6~tuSm_fKxKq6 zR&Rq)l#6EMizf%o25YPCZk8a@&N-hwm?%(4pC^AOx$wkL54ZSg@!TWYgN|t3{(!s7 zM*6A^To!%b^PDfFqC^bdb+giyaM(rXZ`ROVWIsDjcQdkPclFA+PHhNAyCATmL$idE9t2BeiObR zt9!mzlI)*oK7`=bnOz2UB%b$`hZ@6F_VE?39!X-ho4FO7eqEpYRM4?;&XWU~TzYZ7 zLLDc&inmN15y{z~Zd|Rlnykp-^|!1~jz)-6tCyb~-DM?fd&8 z&0?(}ulIVUq}s;h>onj4P1?+-?Q|@XsSfv^uOj@<--s}#t%E^4Rh?vP`$evX#+8WKV zF=5bt@24!kq5b_?tGP(1H?fd@Wk=Al4zWN-*$dQuf!U9ny;`*qXKvkKh#3t$&&jLY zeVw{URI{hZP*hlCk|jRVMCe*1o0&taKJvBXXx{n(l{$wijZh&|6GKs>n-oF^9mk%& zLA`y##wNhQUO)auZPLb(zn%J?)rT^%>(}M;LRu^5P8NJxiT!Zcb*(IKS^8F(#5j;jun=@e98EU4`Y?6A$;BR zXql*>Sg(~b=a615W_+Jn;KB>xU=o>`YQLJh;usNf>&yVf>lwlUfn$Q_GONJh`@FI* zB{K2jLPs*b#aPFr846WYB_<1^Y$$}3Ing0gjJj@D*MwTd3#TQ!&iOqqy%?0e#yZ39 z(3MmZZ*giSATM`{c%@mSseKo+mkTAyS9^LPDfwbWn%QEJ?$MC0r*FwP8-%M2yA4$j zY*AiGJMJ-l-*H*RkpAZNB6jbu5*J)NRJ)X%JdFA>^6DDmDI9I+6e_uC`8N5m{bt_# z-aR|lAEy@cR_{uRi?@U8#|0AiTSofANR}%{bKyZgq>UfdKDWuXZ%5uRf{oeIsJl zc*baOLs3iPiIv%jV9qZuJXGXnFEgyK>2uXS&frOjA@=BH8q_k1wNb1<8VcM{o>sF; zJg!=a}AC5X}}!IU_W{0ZvK7)P^q zBO-q}?SgikqF0{+s_5zg&g$n+()5>zNCYS>3`E}&!>#UA?#~U4W?y30w@2|k2o0=} zxO+%u*y$3TkbzX5=z7SLcUJbbnp&rza=E1Uk`fw25&0hrmRK87>Y87relIk@-jjt{ zWn!x%jAaJ3d|ncwi>ySM=7ab9Oset%BO4lnx5b?IbTCgo&gvM?hYInv zUZ8y9m|>_Pcr=k#ql4s-WkZpF>q!Ptr&o=~i?}U^YcF4H)HXR+H4AfDn_#o}-Whw< z7w6oMS}LlKoB@05Em`c$4-rJXITc{6+iY_}6#3C@rm#1uB|FDx;% zm(n?#c(t1mNiIsJY;Wk3;{3AbVPi+!{*o$BMRuFQ%Yj#g^Iutp8jz8+3mp>eR?c8F zEH2J{D)EBLXGu>H;Zm61UO%5seB{c#7FLd4iDlUZM}~B(bnE9^8eN8VdYX@TgjKGI 
z%XmIGoTogLSVBFJXS`4-YGh%~A8+{9j!`DA+Sb;r?pm@U!(qWqdD^t*`N+fmL7x@O zi-PEr7PY%i5w)91N#+VuQ5Xn}=yN`P;l3f}NwZ#+Ejdze8NJi~5_40MH>H(GaIRit&|(01?It0lvbe%x{lhidb4W9@qS)MA5D@Ys{$;(Ti(WATLL`;t@Pru_0n zGx?bjq4DFqciyl_pP9DmAScc^ov6W!hK95|qgF~btHG*%<&%WjyaI8|E)=G{#5(F9 zI;|sOR+w}m*B9wsTokJMrkSWHn0VVXbu%h`#=U7N%h+By>UMP3=P2zWi4V#{y)n}2~c%)^2q{Qqpbx4&R?Ytz>% literal 0 HcmV?d00001 diff --git a/include/infinicore_infer/models/qwen3vl.h b/include/infinicore_infer/models/qwen3vl.h index ea8e6eee..ee3d59a2 100644 --- a/include/infinicore_infer/models/qwen3vl.h +++ b/include/infinicore_infer/models/qwen3vl.h @@ -169,11 +169,16 @@ dropQwen3vlCache(const struct Qwen3vlModel *, /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void inferBatchQwen3vl(struct Qwen3vlModel *, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct Qwen3vlCache **caches, - const float *temperature, const uint32_t *topk, const float *topp, - uint32_t *output); + const uint32_t *tokens, uint32_t ntok, + void *pixel_values, uint32_t total_patches, + uint32_t *image_grid_thw, uint32_t num_images, + void *pixel_values_videos, uint32_t total_patches_videos, + uint32_t *video_grid_thw, uint32_t num_videos, + uint32_t patch_features, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); /// @brief 批次推理一轮,输出 output embedding 后的 logits /// @param tokens 输入 token 地址 @@ -185,21 +190,14 @@ inferBatchQwen3vl(struct Qwen3vlModel *, /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void forwardBatchQwen3vl(struct Qwen3vlModel *, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct Qwen3vlCache **caches, - void *logits); + const uint32_t *tokens, uint32_t ntok, + void *pixel_values, uint32_t total_patches, + uint32_t *image_grid_thw, uint32_t 
num_images, + void *pixel_values_videos, uint32_t total_patches_videos, + uint32_t *video_grid_thw, uint32_t num_videos, + uint32_t patch_features, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct Qwen3vlCache **caches, + void *logits); #endif // QWEN3VL_WEIGHTS_H - -// self, -// input_ids: torch.LongTensor = None, -// attention_mask: Optional[torch.Tensor] = None, -// position_ids: Optional[torch.LongTensor] = None, -// past_key_values: Optional[Cache] = None, -// inputs_embeds: Optional[torch.FloatTensor] = None, -// pixel_values: Optional[torch.Tensor] = None, -// pixel_values_videos: Optional[torch.FloatTensor] = None, -// image_grid_thw: Optional[torch.LongTensor] = None, -// video_grid_thw: Optional[torch.LongTensor] = None, -// cache_position: Optional[torch.LongTensor] = None, \ No newline at end of file diff --git a/qwen3vl_test.sh b/qwen3vl_test.sh new file mode 100755 index 00000000..c83e47aa --- /dev/null +++ b/qwen3vl_test.sh @@ -0,0 +1,27 @@ +#!/bin/bash +#SBATCH --job-name=test_job # 任务名 +#SBATCH --output=output_%j.log # 标准输出文件(%j 会替换成 job ID) +#SBATCH --error=error_%j.log # 标准错误输出文件 +#SBATCH --partition=nvidia # 分区名(机器系统默认分区是 nvidia) +#SBATCH --nodes=1 # 需要的节点数 +#SBATCH --ntasks=1 # 总任务数(通常 = 节点数 × 每节点任务数) +#SBATCH --cpus-per-task=8 # 每个任务需要的 CPU 核心数 +#SBATCH --gres=gpu:nvidia:4 # 请求 4 块 GPU(nvidia 是 Gres 类型) +#SBATCH --mem=32G # 请求的内存 + +# 需要用到计算资源的命令 +# 推荐使用 srun 启动主程序,自动绑定资源 +source /data/apps/env.sh +source /data/apps/miniforge3/etc/profile.d/conda.sh +conda activate py313 +export INFINI_ROOT=$HOME/.infini +export LD_LIBRARY_PATH=$INFINI_ROOT/lib:$LD_LIBRARY_PATH +export PATH="/data/apps/xmake/bin:/usr/local/cuda/bin:$PATH" + +export PYTHONPATH=$HOME/InfiniLM/scripts:$PYTHONPATH + +cd $HOME/InfiniLM + +#srun python scripts/qwen3vl_test.py +#srun python scripts/qwen3vl.py --nvidia /data/shared/models/Qwen3-VL-2B-Instruct +srun python scripts/launch_server.py --model-path /data/shared/models/Qwen3-VL-2B-Instruct --dev 
nvidia --ndev 4 \ No newline at end of file diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..aca61285 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -1,8 +1,9 @@ class InferTask: - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + def __init__(self, id, inputs, max_tokens, temperature, topk, topp, end_tokens): self.id = id self.finish_reason = None - self.tokens = tokens + self.inputs = inputs + self.tokens = inputs['input_ids'][0].tolist() self.max_tokens = max_tokens self.temperature = temperature self.topk = topk diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 01b2f60d..d5f8a18a 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -102,8 +102,8 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): # A wrapper for InferTask that supports async output queue class AsyncInferTask(InferTask): - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): - super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + def __init__(self, id, inputs, max_tokens, temperature, topk, topp, end_tokens): + super().__init__(id, inputs, max_tokens, temperature, topk, topp, end_tokens) self.output_queue = janus.Queue() print(f"[INFO] Create InferTask {self.id}") @@ -171,15 +171,18 @@ def worker_loop(app): def build_task(id_, request_data, request: Request): messages = request_data.get("messages", []) - input_content = request.app.state.model.tokenizer.apply_chat_template( - conversation=messages, + inputs = request.app.state.model.processor.apply_chat_template( + messages, + tokenize=True, add_generation_prompt=True, - tokenize=False, + return_dict=True, + return_tensors="pt", ) - tokens = request.app.state.model.tokenizer.encode(input_content) + inputs.pop("token_type_ids", None) + return AsyncInferTask( id_, - tokens, + inputs, request_data.get("max_tokens", request.app.state.model.max_context_len()), 
request_data.get("temperature", 1.0), request_data.get("top_k", 1), diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 3cd85f4f..0661d865 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -6,7 +6,6 @@ DeepSeekV3MetaCStruct, DeepSeekV3WeightsCStruct, DeepSeekV3WeightLoaderCStruct, - DeepSeekV3CacheCStruct, ) from .qwen3vl import ( Qwen3vlModel, @@ -42,5 +41,6 @@ "Qwen3vlWeightLoaderCStruct", "Qwen3vlVisWeightLoaderCStruct", "Qwen3vlLangWeightLoaderCStruct", + "Qwen3vlCacheCStruct", "ModelRegister", ] diff --git a/scripts/libinfinicore_infer/base.py b/scripts/libinfinicore_infer/base.py index bed65b2e..93ddf0f9 100644 --- a/scripts/libinfinicore_infer/base.py +++ b/scripts/libinfinicore_infer/base.py @@ -67,4 +67,5 @@ def _load_library(self): lib_path = os.path.join( os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" ) + print("loaded infini lib!") return ctypes.CDLL(lib_path) diff --git a/scripts/libinfinicore_infer/qwen3vl.py b/scripts/libinfinicore_infer/qwen3vl.py index 949ba228..bc405770 100644 --- a/scripts/libinfinicore_infer/qwen3vl.py +++ b/scripts/libinfinicore_infer/qwen3vl.py @@ -2,6 +2,7 @@ from ctypes import ( c_size_t, c_uint, + c_uint16, c_int, c_float, c_void_p, @@ -19,12 +20,14 @@ class TextMetaCStruct(Structure): ("head_dim", c_size_t), ("hidden_size", c_size_t), ("initializer_range", c_float), + ("_pad1", c_float), ("intermediate_size", c_size_t), ("max_tokens", c_size_t), ("num_attention_heads", c_size_t), ("num_hidden_layers", c_size_t), ("num_key_value_heads", c_size_t), ("rms_norm_eps", c_float), + ("_pad2", c_float), ("mrope_section", c_size_t * 3), ("rope_theta", c_size_t), ("vocab_size", c_size_t), @@ -38,6 +41,7 @@ class VisMetaCStruct(Structure): ("hidden_size", c_size_t), ("in_channels", c_size_t), ("initializer_range", c_float), + ("_pad1", c_float), ("intermediate_size", c_size_t), ("num_heads", c_size_t), 
("num_position_embeddings", c_size_t), @@ -51,6 +55,7 @@ class VisMetaCStruct(Structure): class Qwen3vlMetaCStruct(Structure): _fields_ = [ ("dtype", DataType), + ("_pad_dtype", c_uint), ("text_meta", TextMetaCStruct), ("vis_meta", VisMetaCStruct), # Token ids @@ -178,6 +183,15 @@ def register_lib(cls, lib): POINTER(Qwen3vlModelCStruct), POINTER(c_uint), c_uint, + c_void_p, # pixel_values, + c_uint, # total_patches, + POINTER(c_uint), # image_grid_thw, + c_uint, # num_images, + c_void_p, # pixel_values_videos, + c_uint, # total_patches_videos, + POINTER(c_uint), # video_grid_thw, + c_uint, # num_videos, + c_uint, # patch_features, POINTER(c_uint), c_uint, POINTER(c_uint), @@ -192,6 +206,15 @@ def register_lib(cls, lib): POINTER(Qwen3vlModelCStruct), POINTER(c_uint), c_uint, + c_void_p, # pixel_values, + c_uint, # total_patches, + POINTER(c_uint), # image_grid_thw, + c_uint, # num_images, + c_void_p, # pixel_values_videos, + c_uint, # total_patches_videos, + POINTER(c_uint), # video_grid_thw, + c_uint, # num_videos, + c_uint, # patch_features, POINTER(c_uint), c_uint, POINTER(c_uint), @@ -222,6 +245,15 @@ def infer_batch( model, tokens, ntok, + pixel_values, + total_patches, + image_grid_thw, + num_images, + pixel_values_videos, + total_patches_videos, + video_grid_thw, + num_videos, + patch_features, req_lens, nreq, req_pos, @@ -235,6 +267,15 @@ def infer_batch( model, tokens, ntok, + pixel_values, + total_patches, + image_grid_thw, + num_images, + pixel_values_videos, + total_patches_videos, + video_grid_thw, + num_videos, + patch_features, req_lens, nreq, req_pos, @@ -250,6 +291,15 @@ def forward_batch( model, tokens, ntok, + pixel_values, + total_patches, + image_grid_thw, + num_images, + pixel_values_videos, + total_patches_videos, + video_grid_thw, + num_videos, + patch_features, req_lens, nreq, req_pos, @@ -260,6 +310,15 @@ def forward_batch( model, tokens, ntok, + pixel_values, + total_patches, + image_grid_thw, + num_images, + pixel_values_videos, + 
total_patches_videos, + video_grid_thw, + num_videos, + patch_features, req_lens, nreq, req_pos, diff --git a/scripts/qwen3vl.py b/scripts/qwen3vl.py index 9f7468af..51433ebc 100644 --- a/scripts/qwen3vl.py +++ b/scripts/qwen3vl.py @@ -15,7 +15,7 @@ ) from infer_task import InferTask, KVCache -from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref, c_bool +from ctypes import POINTER, c_float, c_int, c_uint, c_uint16, c_void_p, byref, c_bool import os from pathlib import Path import safetensors @@ -25,7 +25,6 @@ import math import torch import transformers - torch.set_default_device("cpu") @@ -451,10 +450,64 @@ def __init__(self, tasks: List[InferTask]): self.topks = (c_uint * self.nreq)(*self.topks_list) self.topps = (c_float * self.nreq)(*self.topps_list) + # initialize visual encoder inputs + self.pixel_values = None + self.total_patches = 0 + self.image_grid_thw = None + self.num_images = 0 + self.pixel_values_videos = None + self.total_patches_videos = 0 + self.video_grid_thw = None + self.num_videos = 0 + self.patch_features = 0 + + # Prepare visual encoder inputs + all_pixel_values = [t.inputs['pixel_values'] for t in tasks if 'pixel_values' in t.inputs] + all_image_grid_thw = [t.inputs['image_grid_thw'] for t in tasks if 'image_grid_thw' in t.inputs] + all_pixel_values_videos = [t.inputs['pixel_values_videos'] for t in tasks if 'pixel_values_videos' in t.inputs] + all_video_grid_thw = [t.inputs['video_grid_thw'] for t in tasks if 'video_grid_thw' in t.inputs] + + if all_pixel_values: + concat_pixel_values = torch.cat(all_pixel_values, dim=0) # (total_patches, features) + self.total_patches = concat_pixel_values.shape[0] + self.patch_features = concat_pixel_values.shape[1] + self.flat_pixels = concat_pixel_values.flatten().to(torch.bfloat16).contiguous() + self.pixel_values = self.flat_pixels.ctypes.data_as(c_void_p) + + if all_image_grid_thw: + concat_grid_thw = torch.cat(all_image_grid_thw, dim=0) # (total_images, 3) + self.num_images = 
concat_grid_thw.shape[0] + flat_grid = concat_grid_thw.flatten().to(torch.int32).contiguous() + self.image_grid_thw = (c_uint * len(flat_grid))(*flat_grid.tolist()) + + if all_pixel_values_videos: + concat_pixel_values_videos = torch.cat(all_pixel_values_videos, dim=0) # (total_patches_videos, features) + self.total_patches_videos = concat_pixel_values_videos.shape[0] + self.patch_features_videos = concat_pixel_values_videos.shape[1] + print(self.patch_features_videos, flush=True) + self.flat_pixels_videos = concat_pixel_values_videos.flatten().to(torch.bfloat16).contiguous() + self.pixel_values_videos = self.flat_pixels_videos.ctypes.data_as(c_void_p) + + if all_video_grid_thw: + concat_grid_thw_videos = torch.cat(all_video_grid_thw, dim=0) # (total_videos, 3) + self.num_videos = concat_grid_thw_videos.shape[0] + flat_grid_videos = concat_grid_thw_videos.flatten().to(torch.int32).contiguous() + self.video_grid_thw = (c_uint * len(flat_grid_videos))(*flat_grid_videos.tolist()) + + def input_args(self): return ( self.tokens, self.ntok, + self.pixel_values, + self.total_patches, + self.image_grid_thw, + self.num_images, + self.pixel_values_videos, + self.total_patches_videos, + self.video_grid_thw, + self.num_videos, + self.patch_features, self.req_lens, self.nreq, self.req_pos, @@ -483,6 +536,7 @@ def __init__( self.meta = Qwen3vlMeta( config, max_tokens=max_tokens ) + self.processor = transformers.AutoProcessor.from_pretrained(model_dir_path) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) else: raise ValueError("Unsupported model architecture") @@ -530,16 +584,17 @@ def batch_infer_one_round(self, tasks: List[InferTask]): return list(output) def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): - input_content = self.tokenizer.apply_chat_template( - conversation=[{"role": "user", "content": input_content}], + inputs = self.processor.apply_chat_template( + conversation = [{"role": "user","content": [{"type": 
"text", "text": input_content}]}], + tokenize=True, add_generation_prompt=True, - tokenize=False, + return_dict=True, + return_tensors="pt", ) - tokens = self.tokenizer.encode(input_content) infer_task = InferTask( 0, - tokens, + inputs, self.max_context_len(), temperature_, topk_, @@ -552,7 +607,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. total_time = 0 output_content = "" - print(tokens) + print(inputs['input_ids'][0].tolist(), flush=True) for step_i in range(max_steps): start_time = time.time() @@ -572,6 +627,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. print("\n") avg_time = total_time * 1000 / steps if steps > 0 else -1 + print(output_content, flush=True) print(f"Time per step: {avg_time:.3f}ms") infer_task._kv_cache.drop(self) diff --git a/scripts/qwen3vl_test.py b/scripts/qwen3vl_test.py index 8c58e637..354008c1 100644 --- a/scripts/qwen3vl_test.py +++ b/scripts/qwen3vl_test.py @@ -6,20 +6,19 @@ # 加载模型和processor # 修改为使用Qwen3VLForConditionalGeneration和AutoProcessor model = Qwen3VLForConditionalGeneration.from_pretrained( - "/home/user/workshop/Qwen3-VL-2B-Instruct/", + "/data/shared/models/Qwen3-VL-2B-Instruct/", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa", trust_remote_code=True ) -processor = AutoProcessor.from_pretrained("/home/user/workshop/Qwen3-VL-2B-Instruct/", trust_remote_code=True) +processor = AutoProcessor.from_pretrained("/data/shared/models/Qwen3-VL-2B-Instruct/", trust_remote_code=True) # 设置生成配置以确保确定性生成 -model.generation_config = GenerationConfig.from_pretrained("/home/user/workshop/Qwen3-VL-2B-Instruct/", trust_remote_code=True) +model.generation_config = GenerationConfig.from_pretrained("/data/shared/models/Qwen3-VL-2B-Instruct/", trust_remote_code=True) model.generation_config.do_sample = False # 关闭采样以确保确定性 model.generation_config.max_new_tokens = 200 -# 输入消息 - 结合文本和图像(这里仅保留文本示例) messages = [ { "role": "user", @@ -31,6 +30,21 @@ ] 
} ] +# messages = [ +# { +# "role":"user", +# "content":[ +# { +# "type":"image", +# "url": "/data/users/monitor1379/InfiniLM/010P00002405F02D94-1.jpg" +# }, +# { +# "type":"text", +# "text":"Describe this image." +# } +# ] +# } +# ] # 处理输入 inputs = processor.apply_chat_template( @@ -40,63 +54,30 @@ return_dict=True, return_tensors="pt", ) + inputs = {k: v.to(model.device) for k, v in inputs.items()} inputs.pop("token_type_ids", None) -print("Input token IDs:", inputs["input_ids"][0].tolist()) -print("Input text:", processor.decode(inputs["input_ids"][0])) - -# 获取输入信息用于逐token生成 -input_ids = inputs["input_ids"] -attention_mask = inputs["attention_mask"] +# for k,v in inputs.items(): +# print(k) +# print(v.shape) +# print(v.dtype) +# print(v) -# 记录开始生成时的总token数 -initial_length = input_ids.shape[1] -generated_tokens = [] -generation_times = [] +# 添加时间统计逻辑 +start_time = time.time() +generated_ids = model.generate(**inputs, max_new_tokens=200, output_attentions=False, return_dict_in_generate=True) +end_time = time.time() -# 逐token生成 -with torch.no_grad(): - current_input_ids = input_ids - current_attention_mask = attention_mask - - # 获取EOS token ID - eos_token_id = model.generation_config.eos_token_id - - for i in range(model.generation_config.max_new_tokens): - start_time = time.time() - - # 单步生成 - outputs = model( - input_ids=current_input_ids, - attention_mask=current_attention_mask, - ) - - # 获取下一个token - next_token_logits = outputs.logits[:, -1, :] - next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) - - # 检查是否达到结束条件 - if next_token_id.item() == eos_token_id: - break - - # 记录生成时间 - end_time = time.time() - generation_times.append((end_time - start_time) * 1000) # 转换为毫秒 - - # 添加到已生成的token中 - generated_tokens.append(next_token_id.item()) - - # 更新输入以包含新生成的token - current_input_ids = torch.cat([current_input_ids, next_token_id], dim=1) - current_attention_mask = torch.cat([current_attention_mask, torch.ones((current_attention_mask.shape[0], 1), 
dtype=current_attention_mask.dtype, device=current_attention_mask.device)], dim=1) +total_time = end_time - start_time +num_steps = len(generated_ids.sequences[0]) - len(inputs['input_ids'][0]) # 减去输入长度得到生成步骤数 +avg_time = (total_time / num_steps) * 1000 # 转换为毫秒 -# 计算平均生成时间 -if generation_times: - avg_generation_time = sum(generation_times) / len(generation_times) - print(f"生成的tokens: {generated_tokens}") - print(f"生成的文本: {processor.decode(generated_tokens, skip_special_tokens=True)}") - print(f"生成的token数量: {len(generated_tokens)}") - print(f"平均生成一个token的时间: {avg_generation_time:.3f} ms") -else: - print("未生成任何新token") \ No newline at end of file +generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids.sequences) +] +output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text[0]) +print(f"Time per step: {avg_time:.3f}ms") \ No newline at end of file diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 00000000..96385c38 --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,2 @@ +if __name__ == "__main__": + print("testing") \ No newline at end of file diff --git a/scripts/test_perf.py b/scripts/test_perf.py index b1951186..474d7736 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -49,7 +49,7 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): start_time = time.time() stream = await client.chat.completions.create( model=MODEL, - messages=[{"role": "user", "content": question}], + messages=[{"role": "user","content": [{"type": "text", "text": question}]}], stream=True ) diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index a8d1cd99..dab9f68e 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -153,6 +153,7 @@ class LRUDescriptorCache { class CacheManager { public: 
DECLARE_OP_CACHE(Add) + DECLARE_OP_CACHE(Conv) DECLARE_OP_CACHE(Mul) DECLARE_OP_CACHE(RMSNorm) DECLARE_OP_CACHE(Gemm) @@ -167,6 +168,7 @@ class CacheManager { CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), + Conv_cache(capacity, DESTROY_FUNC(Conv)), Mul_cache(capacity, DESTROY_FUNC(Mul)), RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)), Gemm_cache(capacity, DESTROY_FUNC(Gemm)), diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index edf3fd96..5ac5ee1f 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -33,6 +33,40 @@ void InferenceContext::add(std::shared_ptr c, c->data(), a->data(), b->data(), stream)); } +void InferenceContext::conv(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr bias, + void *pads, + void *strides, + void *dilations, + size_t n) { + size_t key = CacheManager::createDescriptorKey(y, x, w, bias); + // Combine additional parameters into the key for unique identification + hash_combine(key, std::hash()(pads)); + hash_combine(key, std::hash()(strides)); + hash_combine(key, std::hash()(dilations)); + hash_combine(key, std::hash()(n)); + + infiniopConvDescriptor_t desc; + if (!cache_manager->getConvDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateConvDescriptor( + op_handle, &desc, y->desc(), x->desc(), w->desc(), + bias ? bias->desc() : nullptr, pads, strides, dilations, n)); + cache_manager->putConvDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetConvWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopConv( + desc, workspace, workspace_size, + y->data(), x->data(), w->data(), + bias ? 
bias->data() : nullptr, stream)); +} + void InferenceContext::mul(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b) { diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 76671777..c19f40a3 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -19,6 +19,11 @@ struct InferenceContext { void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); + void conv(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr bias, + void *pads, void *strides, void *dilations, size_t n); void mul(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); @@ -86,6 +91,11 @@ inline void add(std::shared_ptr c, std::shared_ptr a, std::share getInferenceContext().add(c, a, b); } +inline void conv(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, std::shared_ptr bias, + void *pads, void *strides, void *dilations, size_t n) { + getInferenceContext().conv(y, x, w, bias, pads, strides, dilations, n); +} + inline void mul(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b) { getInferenceContext().mul(c, a, b); } diff --git a/src/models/qwen3vl/qwen3vl.cpp b/src/models/qwen3vl/qwen3vl.cpp index d4bdaa12..8a710108 100644 --- a/src/models/qwen3vl/qwen3vl.cpp +++ b/src/models/qwen3vl/qwen3vl.cpp @@ -48,20 +48,279 @@ void releaseDeviceResource(Qwen3vlDeviceResource &res) { res.comm = nullptr; } -//todo: -// pd分离 -// flashattn + batching -// triron跨平台 -// pageattn +inline std::shared_ptr get_custom_SinTable(const Qwen3vlMeta &meta, std::vector> &pos_ids ,uint32_t dim, size_t theta) { + // pos_ids shape:[seq, dim/2] , pos ids acting on each dim + auto unit = dsize(meta.dtype); + auto half_dim = dim/2; + size_t len = pos_ids.size(); + void *table = std::malloc(len * half_dim * unit); + + for (size_t i = 0; i (pos_ids[i][j]) / std::pow(theta, static_cast(j) / half_dim)); + if (meta.dtype == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dim + j] = f32_to_f16(_cos); + } else if 
(meta.dtype == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dim + j] = f32_to_bf16(_cos); + } else if (meta.dtype == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dim + j] = _cos; + } else { + std::cout << "unsupported data type" << std::endl; + exit(1); + } + } + } + auto shape = std::vector({len, half_dim}); + auto tensor = Tensor::weight(table, meta.dtype, shape); + std::free(table); + return tensor; +} +inline std::shared_ptr get_custom_CosTable(const Qwen3vlMeta &meta, std::vector> &pos_ids ,uint32_t dim, size_t theta) { + // pos_ids shape:[seq, dim/2] , pos ids acting on each dim + auto unit = dsize(meta.dtype); + auto half_dim = dim/2; + size_t len = pos_ids.size(); + void *table = std::malloc(len * half_dim * unit); + + for (size_t i = 0; i (pos_ids[i][j]) / std::pow(theta, static_cast(j) / half_dim)); + if (meta.dtype == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dim + j] = f32_to_f16(_cos); + } else if (meta.dtype == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dim + j] = f32_to_bf16(_cos); + } else if (meta.dtype == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dim + j] = _cos; + } else { + std::cout << "unsupported data type" << std::endl; + exit(1); + } + } + } + auto shape = std::vector({len, half_dim}); + auto tensor = Tensor::weight(table, meta.dtype, shape); + std::free(table); + return tensor; +} + +inline std::shared_ptr fast_pos_embed_interpolate(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, + uint32_t* grid_thw, uint32_t num_batch, uint32_t total_patches) { + auto dtype = meta.dtype; + auto num_position_embeddings = meta.vis_meta.num_position_embeddings; + auto hidden_size = meta.vis_meta.hidden_size; + auto merge_size = meta.vis_meta.spatial_merge_size; + auto num_grid_per_side = static_cast(sqrt(num_position_embeddings)); + + uint32_t total_pixels_offset = 0; + std::shared_ptr patch_pos_embeds = Tensor::buffer(dtype,{total_patches, hidden_size},rsrc.memory_pool); + auto pos_embed_weight = 
rsrc.weights->w_vis->pos_embed_weight; + + std::vector> pos_embeds(4); + for (uint32_t i = 0; i < num_batch; ++i) { + uint32_t t = grid_thw[i * 3]; + uint32_t h = grid_thw[i * 3 + 1]; + uint32_t w = grid_thw[i * 3 + 2]; + auto weight_array = std::vector(h*w*hidden_size); + auto weight_tensor = Tensor::buffer(dtype,{h*w, hidden_size},rsrc.memory_pool); + + // 计算插值索引和权重 + std::vector> indices(4); + std::vector> weights(4); + + auto linspace = [](float start, float end, uint32_t num_points) -> std::vector { + std::vector res(num_points); + for (uint32_t i = 0; i < num_points; ++i) { + res[i] = start + (end - start) * i / (num_points - 1); + } + return res; + }; + + auto h_idxs = linspace(0, num_grid_per_side - 1, h); + auto w_idxs = linspace(0, num_grid_per_side - 1, w); + + for (uint32_t ih = 0; ih < h; ++ih) { + for (uint32_t iw = 0; iw < w; ++iw) { + float h_idx_f = h_idxs[ih], w_idx_f = w_idxs[iw]; + uint32_t h_idx_floor = static_cast(floor(h_idx_f)), + w_idx_floor = static_cast(floor(w_idx_f)); + uint32_t h_idx_ceil = std::min(static_cast(ceil(h_idx_f)), num_grid_per_side - 1), + w_idx_ceil = std::min(static_cast(ceil(w_idx_f)), num_grid_per_side - 1); + + float dh = h_idx_f - h_idx_floor, dw = w_idx_f - w_idx_floor; + + indices[0].push_back((h_idx_floor * num_grid_per_side) + w_idx_floor); + indices[1].push_back((h_idx_floor * num_grid_per_side) + w_idx_ceil); + indices[2].push_back((h_idx_ceil * num_grid_per_side) + w_idx_floor); + indices[3].push_back((h_idx_ceil * num_grid_per_side) + w_idx_ceil); + + weights[0].push_back((1 - dh) * (1 - dw)); + weights[1].push_back((1 - dh) * dw); + weights[2].push_back(dh * (1 - dw)); + weights[3].push_back(dh * dw); + } + } + + // 查表并加权求和 + for (int j = 0; j < 4; ++j) { + pos_embeds[j] = Tensor::buffer(dtype,{h*w, hidden_size},rsrc.memory_pool); + // 使用索引和权重获取对应位置嵌入,并乘以权重 + for(size_t i = 0; i < h*w; i++){ + rearrange(pos_embeds[j]->slice(0,i,1),pos_embed_weight->slice(0,indices[j][i],1)); + } + for(size_t i = 0; i < h*w; 
i++){ + uint16_t w_value = f32_to_bf16(weights[j][i]); + for(size_t k=0; k < hidden_size; k++){ + weight_array[i*hidden_size + k] = w_value; + } + } + RUN_INFINI(infinirtMemcpyAsync(weight_tensor->data(), weight_array.data(), sizeof(uint16_t)*h*w*hidden_size, + INFINIRT_MEMCPY_H2D, rsrc.stream)); + mul(pos_embeds[j],pos_embeds[j],weight_tensor); + } + + // 合并四个方向的结果 + auto patch_pos_embed = pos_embeds[0]; // [h*w, hidden_size] + for (int j = 1; j < 4; ++j) { + add(patch_pos_embed,patch_pos_embed, pos_embeds[j]); + } + + // 对于视频帧数T>1的情况,重复patch_pos_embed T次 + if (t > 1) { + auto temp_patch_pos_embed = Tensor::buffer(dtype,{t,h*w,hidden_size},rsrc.memory_pool); + for(size_t i = 0; i < t; i++){ + rearrange(temp_patch_pos_embed->slice(0,i,1), patch_pos_embed); + } + patch_pos_embed = temp_patch_pos_embed; + } + printf("merge patch pos embed/n"); + fflush(stdout); + patch_pos_embed = patch_pos_embed + ->view({t, h/merge_size, merge_size, w/merge_size, merge_size, hidden_size}) + ->permute({0, 1, 3, 2, 4, 5}) + ->view({t*h*w, hidden_size}); //可能因为内存不连续无法再view + + rearrange(patch_pos_embeds->slice(0,total_pixels_offset,t*h*w), patch_pos_embed); + total_pixels_offset += t*h*w; + } + return patch_pos_embeds; +} + +inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, uint32_t* grid_thw, uint32_t num_batch, uint32_t total_patches) { + auto dtype = meta.dtype; + auto hidden_size = meta.vis_meta.hidden_size; + auto num_heads = meta.vis_meta.num_heads; + auto head_dim = hidden_size / num_heads; + auto merge_size = meta.vis_meta.spatial_merge_size; + + std::vector> pos_ids_table_y ( + total_patches, + std::vector(head_dim/4) + ); + std::vector> pos_ids_table_x ( + total_patches, + std::vector(head_dim/4) + ); + for (uint32_t b = 0; b < num_batch; ++b) { + uint32_t offset = b * 3; + uint32_t num_frames = grid_thw[offset + 0]; + uint32_t height = grid_thw[offset + 1]; + uint32_t width = grid_thw[offset + 2]; + + uint32_t merged_h = height / merge_size; + 
uint32_t merged_w = width / merge_size; + + // 遍历所有块和块内位置 + size_t patch_offset = 0; + for (uint32_t bh = 0; bh < merged_h; ++bh) { + for (uint32_t bw = 0; bw < merged_w; ++bw) { + for (uint32_t ih = 0; ih < merge_size; ++ih) { + for (uint32_t iw = 0; iw < merge_size; ++iw) { + uint32_t row = bh * merge_size + ih; + uint32_t col = bw * merge_size + iw; + // 如果是多帧,重复 num_frames 次 + for (uint32_t f = 0; f < num_frames; ++f) { + size_t dim_offset = 0; + for(;dim_offsetslice(1,0,head_dim/4),sin_y); + auto sin_x = get_custom_SinTable(meta,pos_ids_table_x,head_dim/2,10000); + rearrange(sin->slice(1,head_dim/4,head_dim/2),sin_y); + auto cos = Tensor::buffer(dtype,{total_patches,head_dim/2},rsrc.memory_pool); + auto cos_y = get_custom_CosTable(meta,pos_ids_table_y,head_dim/2,10000); + rearrange(cos->slice(1,0,head_dim/4),cos_y); + auto cos_x = get_custom_CosTable(meta,pos_ids_table_x,head_dim/2,10000); + rearrange(cos->slice(1,head_dim/4,head_dim/2),cos_y); + + return std::pair{sin,cos}; +} + +void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, InferRequest &req) { + void *pixel_values = req.pixel_values; + uint32_t total_patches = req.total_patches; + uint32_t *image_grid_thw = req.image_grid_thw; + uint32_t num_images = req.num_images; + void *pixel_values_videos = req.pixel_values_videos; + uint32_t total_patches_videos = req.total_patches_videos; + //uint32_t *video_grid_thw = req.video_grid_thw; + //uint32_t num_videos = req.num_videos; + //uint32_t patch_features = req.patch_features; + + auto dtype = meta.dtype; + auto d = meta.vis_meta.hidden_size; + auto channels = meta.vis_meta.in_channels; + auto patch_size = meta.vis_meta.patch_size; + auto temporal_patch_size = meta.vis_meta.temporal_patch_size; + //auto stream = rsrc.stream; + auto weights = rsrc.weights; + + auto image_tensor = Tensor::weight(pixel_values, dtype, {total_patches, channels*temporal_patch_size*patch_size*patch_size}); + auto 
video_tensor = Tensor::weight(pixel_values_videos, dtype, {total_patches_videos, channels*temporal_patch_size*patch_size*patch_size}); + auto hidden_states = Tensor::buffer(dtype, {total_patches, d, 1, 1, 1}, rsrc.memory_pool); + + std::vector pads = {0, 0, 0}; + std::vector strides = {static_cast(temporal_patch_size), static_cast(patch_size), static_cast(patch_size)}; + std::vector dilations = {1, 1, 1}; + conv(hidden_states, image_tensor, rsrc.weights->w_vis->patch_embed_weight, rsrc.weights->w_vis->patch_embed_bias, + pads.data(), strides.data(), dilations.data(), 3); + hidden_states = hidden_states->view({total_patches, d}); + + auto pos_embeds = fast_pos_embed_interpolate(meta,rsrc,image_grid_thw,num_images,total_patches); + add(hidden_states,hidden_states,pos_embeds); + + auto [sin, cos] = rot_pos_embed(meta,rsrc,image_grid_thw,num_images,total_patches); + + +} + +void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, InferRequest &req) { + const uint32_t *tokens = req.tokens; + uint32_t ntok = req.ntok; + const uint32_t *req_lens = req.req_lens; + uint32_t nreq = req.nreq; + const uint32_t *req_pos = req.req_pos; + struct Qwen3vlCache **caches = req.kv_caches; + const float *temperature = req.temperature; + const uint32_t *topk = req.topk; + const float *topp = req.topp; + uint32_t *output = req.output; + void *last_logits = req.logits; -void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, - uint32_t idev, uint32_t ndev, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct Qwen3vlCache **caches, - const float *temperature, const uint32_t *topk, const float *topp, - uint32_t *output, void *last_logits) { assert(meta.text_meta.num_attention_heads % ndev == 0); assert(meta.text_meta.num_key_value_heads % ndev == 0); @@ -256,15 +515,47 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, } } 
+void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, InferState &state, InferRequest &req) { + // infer vision + sync + if (req.num_images > 0 || req.num_videos > 0){ + inferDeviceBatchVision(meta, rsrc, idev, ndev, req); + + std::unique_lock lock(state.mtx_sync); + state.sync_cnt--; + if (state.sync_cnt == 0) { + state.cv_sync.notify_all(); + } else { + state.cv_sync.wait(lock, [&] {return state.sync_cnt == 0;}); + } + } + // infer text + inferDeviceBatchText(meta, rsrc, idev, ndev, req); +} + __C void inferBatchQwen3vl(struct Qwen3vlModel *model, const uint32_t *tokens, uint32_t ntok, + void *pixel_values, uint32_t total_patches, + uint32_t *image_grid_thw, uint32_t num_images, + void *pixel_values_videos, uint32_t total_patches_videos, + uint32_t *video_grid_thw, uint32_t num_videos, + uint32_t patch_features, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct Qwen3vlCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, uint32_t *output) { model->req.tokens = tokens; model->req.ntok = ntok; + model->req.pixel_values = pixel_values; + model->req.total_patches = total_patches; + model->req.image_grid_thw = image_grid_thw; + model->req.num_images = num_images; + model->req.pixel_values_videos = pixel_values_videos; + model->req.total_patches_videos = total_patches_videos; + model->req.video_grid_thw = video_grid_thw; + model->req.num_videos = num_videos; + model->req.patch_features = patch_features; model->req.req_lens = req_lens; model->req.nreq = nreq; model->req.req_pos = req_pos; @@ -274,6 +565,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *model, model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->states[0].sync_cnt = model->dev_ids.size(); for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -292,11 +584,25 @@ inferBatchQwen3vl(struct Qwen3vlModel 
*model, __C void forwardBatchQwen3vl(struct Qwen3vlModel *model, const uint32_t *tokens, uint32_t ntok, + void *pixel_values, uint32_t total_patches, + uint32_t *image_grid_thw, uint32_t num_images, + void *pixel_values_videos, uint32_t total_patches_videos, + uint32_t *video_grid_thw, uint32_t num_videos, + uint32_t patch_features, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct Qwen3vlCache **kv_caches, void *logits) { model->req.tokens = tokens; model->req.ntok = ntok; + model->req.pixel_values = pixel_values; + model->req.total_patches = total_patches; + model->req.image_grid_thw = image_grid_thw; + model->req.num_images = num_images; + model->req.pixel_values_videos = pixel_values_videos; + model->req.total_patches_videos = total_patches_videos; + model->req.video_grid_thw = video_grid_thw; + model->req.num_videos = num_videos; + model->req.patch_features = patch_features; model->req.req_lens = req_lens; model->req.nreq = nreq; model->req.req_pos = req_pos; @@ -306,6 +612,7 @@ forwardBatchQwen3vl(struct Qwen3vlModel *model, model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->states[0].sync_cnt = model->dev_ids.size(); for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -348,9 +655,7 @@ void launchDevice(const Qwen3vlMeta &meta, std::shared_ptr break; } - inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, - req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.temperature, req.topk, req.topp, req.output, req.logits); + inferDeviceBatch(meta, *rsrc, idev, ndev, state, req); state.proceed = false; lock.unlock(); diff --git a/src/models/qwen3vl/qwen3vl_impl.hpp b/src/models/qwen3vl/qwen3vl_impl.hpp index 7e126f6c..76dd9d0d 100644 --- a/src/models/qwen3vl/qwen3vl_impl.hpp +++ b/src/models/qwen3vl/qwen3vl_impl.hpp @@ -12,7 +12,7 @@ #include #include -struct LayerWeight { +struct Qwen3vlLayerWeight { std::shared_ptr 
attn_norm; std::shared_ptr attn_qkv_proj; std::shared_ptr attn_q_norm; @@ -20,15 +20,16 @@ struct LayerWeight { std::shared_ptr attn_o_proj; std::shared_ptr mlp_norm; - std::shared_ptr mlp_down, mlp_gate_up; + std::shared_ptr mlp_gate_up; + std::shared_ptr mlp_down; }; -struct LanguageModelWeight { +struct Qwen3vlLanguageModelWeight { std::shared_ptr in_embd, out_embd, out_norm; - std::vector layers; + std::vector layers; }; -struct VisBlockWeight { +struct Qwen3vlVisBlockWeight { std::shared_ptr attn_proj_weight, attn_proj_bias, attn_qkv_weight, attn_qkv_bias; std::shared_ptr mlp_linear_fc1_weight, mlp_linear_fc1_bias, mlp_linear_fc2_weight, mlp_linear_fc2_bias; std::shared_ptr norm1_weight, norm1_bias, norm2_weight, norm2_bias; @@ -45,9 +46,9 @@ struct MergerWeight { }; -struct VisualEncoderWeight { +struct Qwen3vlVisualEncoderWeight { std::shared_ptr patch_embed_weight, patch_embed_bias, pos_embed_weight; - std::vector blocks; + std::vector blocks; std::vector deepstack_mergers; std::shared_ptr merger; }; @@ -55,8 +56,8 @@ struct VisualEncoderWeight { struct Qwen3vlDeviceWeights { std::shared_ptr sin_table,cos_table; - std::shared_ptr w_lang; - std::shared_ptr w_vis; + std::shared_ptr w_lang; + std::shared_ptr w_vis; infiniDevice_t device; int dev_id; infinirtStream_t load_stream; @@ -89,7 +90,10 @@ struct Qwen3vlDeviceResource { std::shared_ptr memory_pool; }; -struct InferState { +struct InferState { // qwen3vl namespace + inline static std::mutex mtx_sync; + inline static int sync_cnt; + inline static std::condition_variable cv_sync; std::mutex mtx; std::condition_variable cv_load, cv_start, cv_done; bool loaded = false; @@ -97,9 +101,18 @@ struct InferState { bool exit_flag = false; }; -struct InferRequest { +struct InferRequest { // qwen3vl namespace const uint32_t *tokens; uint32_t ntok; + void *pixel_values; + uint32_t total_patches; + uint32_t *image_grid_thw; + uint32_t num_images; + void *pixel_values_videos; + uint32_t total_patches_videos; + uint32_t 
*video_grid_thw; + uint32_t num_videos; + uint32_t patch_features; const uint32_t *req_lens; uint32_t nreq; const uint32_t *req_pos; diff --git a/src/models/qwen3vl/qwen3vl_weight.cpp b/src/models/qwen3vl/qwen3vl_weight.cpp index f525b30e..ce9bbba5 100644 --- a/src/models/qwen3vl/qwen3vl_weight.cpp +++ b/src/models/qwen3vl/qwen3vl_weight.cpp @@ -23,7 +23,7 @@ inline std::shared_ptr getOutEmbd( } inline void getLayerWeight( - const Qwen3vlMeta *meta,LayerWeight& layer, int ndev) { + const Qwen3vlMeta *meta, Qwen3vlLayerWeight& layer, int ndev) { auto nkvh = meta->text_meta.num_key_value_heads; auto nh = meta->text_meta.num_attention_heads; auto dh = meta->text_meta.head_dim; @@ -49,7 +49,7 @@ inline void getLayerWeight( inline void getVisualWeight( - const Qwen3vlMeta *meta, std::shared_ptr w_vis) { + const Qwen3vlMeta *meta, std::shared_ptr w_vis) { Qwen3vlVisMeta vis_meta = meta->vis_meta; auto patch_embed_shape = std::vector({vis_meta.hidden_size , vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size}); w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape); @@ -62,7 +62,7 @@ inline void getVisualWeight( w_vis->merger->linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size}); w_vis->merger->norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); - w_vis->blocks = std::vector(vis_meta.depth); + w_vis->blocks = std::vector(vis_meta.depth); for (size_t i = 0; i < vis_meta.depth; i++) { w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size,vis_meta.hidden_size}); w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); @@ -156,9 +156,8 @@ Qwen3vlWeights::Qwen3vlWeights( device_weights[dev]->device = device; device_weights[dev]->dev_id = dev_id; 
RUN_INFINI(infinirtStreamCreate(&device_weights[dev]->load_stream)); - - device_weights[dev]->w_lang = std::make_shared(); - device_weights[dev]->w_vis = std::make_shared(); + device_weights[dev]->w_lang = std::make_shared(); + device_weights[dev]->w_vis = std::make_shared(); device_weights[dev]->w_lang->in_embd = getInEmbd(meta); device_weights[dev]->w_lang->out_norm = getOutNorm(meta); @@ -166,7 +165,7 @@ Qwen3vlWeights::Qwen3vlWeights( device_weights[dev]->sin_table = getSinTable(meta); device_weights[dev]->cos_table = getCosTable(meta); - device_weights[dev]->w_lang->layers = std::vector(meta->text_meta.num_hidden_layers); + device_weights[dev]->w_lang->layers = std::vector(meta->text_meta.num_hidden_layers); for (size_t layer = 0; layer < meta->text_meta.num_hidden_layers; layer++) { getLayerWeight(meta, device_weights[dev]->w_lang->layers[layer], ndev); @@ -288,7 +287,7 @@ void load_mlp_norm(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { std::cout << "Loading mlp gate " << layer << " from " << cpu_ptr << std::endl; int ndev = int(weights->device_weights.size()); - auto di = weights->meta->text_meta.head_dim; + auto di = weights->meta->text_meta.intermediate_size; auto d = weights->meta->text_meta.hidden_size; // [ndev, 2*di // ndev, d] for (int idev = 0; idev < ndev; idev++) { @@ -306,7 +305,7 @@ void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { std::cout << "Loading mlp down " << layer << " from " << cpu_ptr << std::endl; int ndev = int(weights->device_weights.size()); - auto di = weights->meta->text_meta.head_dim; + auto di = weights->meta->text_meta.intermediate_size; auto d = weights->meta->text_meta.hidden_size; //[ndev, d, di // ndev] for (int idev = 0; idev < ndev; idev++) { @@ -624,6 +623,19 @@ createQwen3vlWeights(const Qwen3vlMeta *meta, int ndev, const int 
*dev_ids, bool transpose_weight) { + + printf("=== C++ createQwen3vlWeights ===\n"); + printf("sizeof(Qwen3vlTextMeta): %zu\n", sizeof(Qwen3vlTextMeta)); + printf("sizeof(Qwen3vlVisMeta): %zu\n", sizeof(Qwen3vlVisMeta)); + printf("sizeof(Qwen3vlMeta): %zu\n", sizeof(Qwen3vlMeta)); + printf("meta->dtype: %d\n", meta->dtype); + printf("meta->text_meta.hidden_size: %zu\n", meta->text_meta.hidden_size); + printf("meta->text_meta.num_hidden_layers: %zu\n", meta->text_meta.num_hidden_layers); + printf("meta->text_meta.vocab_size: %zu\n", meta->text_meta.vocab_size); + printf("meta->vis_meta.depth: %zu\n", meta->vis_meta.depth); + printf("device: %d, ndev: %d, dev_ids[0]: %d\n", device, ndev, dev_ids[0]); + fflush(stdout); + auto weights = new Qwen3vlWeights(meta, device, ndev, dev_ids, transpose_weight); return weights; }; diff --git a/t012ed7ed15c1fafc48.jpg b/t012ed7ed15c1fafc48.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4c762f67b4c221876663ede6efd9a7a11198fe14 GIT binary patch literal 114478 zcmb5VWmHt}7dAX}cStu#JLAyZ-Q6`XAT1>w64DX^Go*A4F$^v6qY)TDngNw=P(l$< z@#X)l^{(ghv(H-R!`b&b_da#+Ywzni|JMF(0Vs7ev^4-&SXcn9#|`jr1E2~Z!Y3dk zz$YRkAS5OxA|a!vAR{FuW2U8}q~~De;^JUsXXoJ;6XxLqJ!NMXkrx4pOUlT|a0@A_ zDN3n|Ny|w7?<82n#KdHzWK0wkOp?6pypsRV_HP(KMTq?kyB7zG3xG|9g+qn)Zxp}` z0N^|Z`xx&3Z`e4vSa<+@f=4IA|MC1^*W(bZNB6&V04WX@0Gkqr^6|j78O3b8mfIID zytGlz4z2e-~D=|fiT&?F5w?wZ-YFiUN&8rNp-@M7|Wh&YpiXFrq+;>dYzlr-$ zY8dgA1ha+uDt~sljW+$&3Ep`BZTox<@9gRcYWyhAy^Et=-=B$VI$-7|+%dv`2*{@& z*gi45e3qTMb?5JOHFm4x@kYJa)L6Zq$9H?}KYBjs{Tv&$d)IX3HU+Q_*hl`s>#HO} zIFx8zNkeMVQfrXJq^MVkFo*J3$#>^_e_IO!Uw)W;WzAvh&Ee}CVn+VqoA53C|G2=cW9mfzT0e*E6TP5o!gX)>ACq?QfQSM#xDrm;yxD7v>p8zgfMmzHco7leNs>XEyiRO{rbN)ANzi)Oh?9bf1Dwuz{Gi+_4Mos z-#qItH6(H&3N3ggw!lg;QXV142CYr_BQ(wzvF6v`z_H27bBJ2qfQ7QSR%;NxEf<)v zH(6UK9zSe_JdG+0|4_i+%9k7DS;rM{lNwEwYYRHJoAzFB4CjR z)h!H{m?nXyBvrnqz9Rn(iTWTW75Z%EXX>>$z51?PDE@Cl=def>?Vj!kOtN)hw=m5{ zA(TBeSyw)?bBp;q0esubB_-@Hc`(drccOy6zmsY7>?hBL*fSZ+Ql4&8=3SbPm(!H} 
z^i2J&wA>ySSCq?4MNjHJ-I#Ncl6lqX+zU%FV=_6p`Gwhad))*_hYy2-%luMQUr|Iv zi8_bQkdesmVk(^vJQOycb`eEYw*4j#}2BS7U=S0vXVTS_si~C$% z=2jlRf*jPKmcv+v_}wdOpTMXX1xcV&Ysm0Ls#u^QrAJuNGgJWTmFcUm`DGE9Bika| zt-;to2B>#viy|9VXZ@h%5=VPPqCk}kGym~RRkw2d>$4J?rndJ*9ejRlugw7LaaZCg>fTJIi^Rmx1H zeQP40p))UmmaEJ<**Q-6^gLhZ?#!YAvDh44JfZ-g`aOZg_$myjbg!rK<=(&qUma0cVNDd~xJ?;LNz9i8VCw-9Z8$&opm ze10XPds^EsY{9gnZ#-8}JxTd=IzKb2J0RK8-BS9-G~85}or7diKP14?N_9$lwKp?H z32QvN26m>KdA9sx=Lm}VwC@)DDSu(CJLk&xKwn%9G4JV-AAfBUw9MtKzu!0KI9?5^z-rWNcDYu{z8H0QhN+4`EdQCM^uifTD)x&y& zAJ@2*?Zet4H99R-`z>h0FZw`*4zb<(fZ!wfqzaS>~v2 z89jEm-GCIx2)65K1orlY`u=v){=Lw*cuWjo(&$FZeQKAhhVKHINJI2AzB3ssnU%M@ zHpuz5P1V#|FYt(l3mUC?Sw|l45iu)NU@AJ=mWPXYZ&>Z53(h1{7N5+1yXKBt${g;A zl504SCKAn)s8pf-P_SUEsJ%oI^hdQL`~3-r%rdF{FDmihECue>{Z7=yrrf4Akzd|c z{dXXv>6Kw)V4V*jXH-S688Z@HYsBkS8+E)WDr_^JeW?^dVJ!ddd{bv-R9iX8OI%I^ zojNx4N!M7y4G3R!e0?P+vJD9lLxbWucy>u>&$+=noJLciK5MRsK8*3 zpJ5&8m(6gQdS;PRshn5RFOCq;r(JA*rABzyvx_XHmNza;V-A9dRmLr6^?%9*=x|5- zDD&hqt8X2jRoJ$&34_Pd&ROMz863dNBY_N+2u6)+V+3jtzlbGHEeNAxzPn7mY$_W1 zE4MIxxtjT03ppnu?`v&MmSGqF!4nvMlM%W6?5y!*+n6l%Yoz7sU3A?tQ-Q>0CSqn3;j~gc~Db{Eu%zf*r!G+OA<)CT4f_;1ZRwcpBF|G%{CA zXrfJLBg|6OH#0Q~6)X&Izo!$Oc_yL5Ss(zPNa$yvu~uH&eqJ_r&t?&9mZuqD94J%u zJZQh|%n%fP|@> zlZ)53r*Mm)FV`j$E0;Y5>+fsCj)Fs@Y{~^{-f8IsnIPW9Mir`9Mc}Gjsqp_egn(S> z>*$}%gf&UjdA0Gyq>fL`Xm&y>oLE45O5n~2X6fcQOukqpclo}Et=w>M%k~e%Q5Rf| z_M2JTb;F@9NZ25CzFMYko`NmYTRTn0V8*V#^@tk@)`=0Sc_ZF;FI4{GiT_h+WE8su z@rK)UQP{>ppecd9`$^zhFvuFROCxZ><|(NvK0uGyx}a`QrW^*!UA8&BH@D|A7Z=ScqpB^Ws%MTA^o})HnqSJBpHp7XHg?y- z6ZUXk!#o9l@MhPD^b7oqNVIGaylV_wp1+St`Ta$~Cpa^>zHZjp#*H2PGX~^#6Ud)h z{_U4orrZ1dmjfgD%I^TvJ?6MV#$SItB#r-=O69P7wK099@4-eaqCP`RVg1^sKgU2< z;~AN)Iu>7vx@kFHW0f#<9`d=qY8IaX3gBTjZWUcdd-rQr^EKYdeqaCmZ4G51VKryt z*^jxRvmkRaw!p@tiCvoz9=4{J;8U$X{{YAzEQon?cp?80uj#vRdB-RH^$r0K5!T$_HEcT3`)PY1Fo>!pPp6@n%`t9lE7!XJ>c9TAH1*%JoGqO-3fUUI%asN$|1Dd-=mEJysA zQ~hGA;J0b9y!F>|j^{UW-{GayrunRb{TGl7Z|TehgKqqH9U2=tbXs3gS=np@kfZ7Ol1?^t6$%J)=Dj+&X6d)N^$)(x?)Nl6}4btJk-dp(vjc{@U@`cJ{SNpz)u> 
z#eun!4@U}KrnRr0Rl6g}^Xv=`@zJM|cbcpJrL5Fvt2%S5@B*%?h20ZfW=&X`vGk4j zY@D`A(3PWbtz@RioD6Q$qL$_%TgarHpm?Tz^e`EbZ z{q>cY61ENYdiFHFQaxON3+{$Ta+bOD!Mfgxc1X$~pMPA@-+GAAUBaCF5aik0>thh3 zJ865IUo{doh5wFD!i|+K4@f__+$-%^S0}4rIIDUV=7pv?*9V{S2&Go5x@{6y7fx2U zbvH5Xwlp&Qd0Q>ut4qh}_fD+yh054u(UvJd%pqTJ^fmg%<<~C^!XuRI#)pSLTTTXu zUo#cBA=fZ2F{u)_V$Hwopqn7L3%2NdGkY$4*L1GKV_G#ETyc-23(jVKU(=A{6y;I> zc7*Mi&$Rp~li%C)h25W?+#Yt~z+%~{1jQ5^T2~AX&dmkI;hY0@QaKHOuH+Z ze!cme)PQ2BF`ElWcO8(KAK?F`@Tu7kS}MJ`Z-kroiMyB5%yi3%YfNxcO2s~rABspT z+n_mI{c(r*(;THx%vR>;Y~w8EWA(1EA7b{YdfO31$rM2-C6kYH(2U)maphqY^O$6 zAl_DzjBAB?7puC+yfV<4uA?L4<#(YC=#lAGWT~cc)iqF+3z^F*{C3De{@{Nkqqi@0 z(ADV(vz?u~y9g5%sVJog(-%kQT2n3GN^vFwo8e>c3~NoI)(h*)M4U~8 z?P>BXOP47Jr?%E5z+wJUO#62;0W zanQ$xEP`Nu{*3D2fZseVnPd31T%*!z_2Ijesh3W09cz=m>Xz~$jO4J`@OW*XtxMtnb>R>_ zbDGpLu=vF(n2cF?j{+}&zjkon_b%b&NG12=d9zK9M(-Ap#OQ~OtzV;J`n(M~gW+FQ zV&Z9pLA2ROh1%!#f@`D|?$5+!Afmzx zI0sb45Awb5Snf(pbdC)2zlC3Y{#6c#t1Uf;aq;hGrQJ|=h|`M;iSf#*0g>8}SZ3V0 z7raLDxId=tmBkx53gKg6zBX(mGPry>IhVfgc69e?*vXIrjk6@`uVFQu(7>Is?Q_2n1R#mf8Z= zd-|~mYl}gi$jBum3unda<^W6@ct#=BUCmZNQRJ;rg(bO1l}+x1>|1d?RGD~j!;8>r z;3h;_3~ap^DYsL-`;-3cM{g2dMmbJN@TdgC%j)n{18!{sA#GV_aUN+=2NXLj{A1Jo zhb;ySTzgi|RS-`8ZsK09%G}|~+Nr8N-9eM5;SlzOUtNo<{+Tk7ICPu3Gu%u@P_8uW z$Jh^jc)bs;PQx!Rtk=*u)z()g668}F6>ZUeh9x%Ed|r3pPee1kAp~%<2I9Z2KZm}g z>>sO(kwyNJ_b(4yYOM1jDk9IO+gjqC;G2wXWo6w_CuXt+-rQ&6wZoSS-Bf5qXe19AR+n zbmxoZz(~7+fU_R2Y39{E`x2v?J>=96s}AYfNEqll|JmmuQXpkIe~YvWLun3HM%8KVFy3EA^bVLfsR3e+$VcFO3TXkSg8DkDT`q}&uYuAVgF2~O}D%#ir!(g zdB$^>G-hsn##PvyXV_8HYyrypihzGVNUL2i!g@SLV!Syw&wG)ogfm%&TU<1qnX)(e z%OOiYVV+%c=9i-HSYE`=c&`;eO!H4Z>=H#lGQ-E>5_8vUV7agcd+qnn^7eO8WvnlC zYNy?ZKhi)dg2y3<(g+nJWHd_#1J!3x8fU+P{Y?~v&JMMkSJ>Xom-WZB>9cDC_-84O zLEBqlZ)swi)MtUhFY-*Lrk~|MxgFgVl%6o2CS6f>fBn8W;f01*|4V~p$R=$T*}>8i z1b&kmgE|rCR7*QYa{S%M7r5Ccf*Bf#G*zG zl(vxQ?DG5T#lt_q+|Qq7Ricmyk)op79h|Rvi)A9xwP)*i(iWQFyA?nAKlW;H1-lQ# z#+O^_6}`Hpmij4eU4cWwq{Xf52Ehr@MN}hsqgk6z-uW=UMm83iB&ih%Ntu+iYpuWk 
z6hwqH$)rFyH4^v^nvHQBrH`vkg2L5uq&szyx?#Vp`Ap(TEcw58g1j01)#|=Ttj@oZ zijplX;Lx476KSBBo%xKFU#VRr5Ev{!E0q&^d?cF4SmG>fQ5L|j%Z+C$pQCO9S_wFB zNGU%e+^zPQBgf6!+0<7(p00hjf?zzx4XSA*!Q{o4aZWqK?#k{Ka4 zrVyj=%B_b*_5wR49j{gGd#CeQvu+nQt34y1= zfX9-wR=6Ooc+JxhbEpr0W2`}mk)xH!CFebV?R^tVI%`1epTzY1m*_DstbM*tt7~P_Wh_q`kY+{mr)vyVMrJ`~RLYxXjj8>NeD!xM zeY#$y_MFK2ChA#aH>Szr7|fD}b!8z2tX#t8wF)@-zIPxx8yZ$fKldj7g+*IbT#tdy z+MsbW)RW)yF{^#V%cs-TrfGvAGg-*J4@h+ zwytJ zuBSd{tme#A&`xtV6p>r!t9cq3YFn{#Y3z>VztTHyhSAR?P4k!B3+T3GffvjbYg5AI zqF-;L$F& zEj)bdoH7yhweWT4UcJQG!om=!XzwWrtz1VF*}J4EogQdYTTkdd<_GWfpZaJo_m-{< zof=c7Y?gjKpogr|f{AE{!aT2t&;VPWEzs1sAv>8koeJGgKcSgPw{p;p)EVxv262SF z6BEf+r9R?t85(zXi;=dqPr3!qrx;eizn()Fs~|kmZFFoQqVkE!({5bZ$lo$ka$U^g z9PuhQFY0Fsf|G>f>awFvZJBvhja5#0%qN+S3hxB_?B$YmrGG9homDJ@^?IL1u3UkveZ^saCC_{vmR zKh@dw3OU6YJf??n7cY?Tu5&1;x(V8kvyR&_E>4r2p|QUfBLyiNs_P2LNl3Rj{j9Fv zb}3+>p&dSBbcNGNyY!{lEBrwku9RteENVL*s7inF&?Dejn2}fuZ(mH$Oq$dO>+#xJ zK$!vbOpU`3Hr2`wLt<3*N-tl|N^#w|@T7-%R)kAiDs`ZNm-MPU_|U>ju1EQy?W1Hx z`U(Dg-uAM4FzM>_oT0(XC)cAVLZrB<pYi#H1TK~>G`AnTvG{yKe8Rs!W7pBi zhpEWH(7Z7xIhDw5>1|4cY zHKxB@viQKNK9FxSqZDc_BQ-&nGNVTol}4XKY|n@Cb|?!yt};n9sa1UBc|<7z*x1-u zxc_JONC*9o<^kYPQFG!_(r|&Eb12ht1M$F0Dq^ZeJamuj58>kwEIe$}&(ZqK4@Ca} z_%S&D00fE`$OjsR5K}NVT2D_&77B&ZP~4*K$_8Kkom>5e|EBfLANPkq4NAQkD4#4$ zkD%wg>{f#?i#*zZ<2)tPw(IoK9Y&dnfhVAEfXRwDpJNgtA2?{Ruba7Eg5d8b*~Xx) zkCVAP1+cfAsiyosi+9nVR~pZdTxt!?hyVvnu&Dm4$B=&xrk81 zj9hBSIZrN`2?}+Brw_LBfLWIw8zUJhErINv95j?@^o=~~fu51X!Rez2!Nk;-XLO9x z1v&S%`ygXk^(~7&l6KD>OK;`*%_xiNOLu~2@h?o=!E4yF4? 
zNZUn?53GCWC&#cjl(mf%{WO`JBpVC9#lWDSDej#YhNPsb7ontwK4$7=8 z%*1Wd_?c({#n5(WC{@ad`=}j17Or_yjYYL6aIDAiMN@_Fe$0P9=i9B0@^ z8s7YdkQ60?ajQ_Zfl}gOJ`QWdikEmHI%aUxW1I%pio6SEF_^5da!WeK)Vl(8<&)5zPo`&7k3?YLcw!P7aqUR2uE5 z`HP($G;~bX)*8$b`=RG}N^&nDAOehF=(P8@>@h@-H+Nb{F8M{Y~1Cu<1=7cYZ4c zaoe$ZP<6gUgT+0emJtay*TM=7p)6h4U~hh`3OXwiu_YCAh#hd1aS3%=HW)LAMuUM2 zdd_eJ)6h9TOQbqXhy?1B?2{hC-{%XAg1A*@N1ZX!OlpA3{guDiyX|Ft6 z+gl^(y>IgQC=(a3{f+W7oSd|KU*gQrEmq(8asBzlo{}&;@64Kj-eDt$=b2$vu(E2k zW?5~poS9^p|(MJvXm54-`W zdOosJR<)Kaz)}xkiBIAD5+?D#uoDP2^oH}H@w8?6T(7R~R?r^dA6&bg&xO9t^A+}1 zI288qCI`5#?cxWxyr5XMFQ~QfY|`If){nY`$$J@}Rk&@7LJ%0*|BO%p5HOT&7)K(p zTs}7bO1_YRbe&?5MQJea7N?joIV`;?ual)+klT|RInIQQ)*_MIL1ib5vrU+Pm@t@; z%Om_^s4v+(7Ccgfg_JK+;?-^HToDtfA_5$Y2@C3zyhiU@8yl@EFd7n^$0bSSA)yal zUeDW#P=<^hF9ml5NNIA%0zJHeKN&J!8C`Z#HHn$}v(D=vOX{s+)|ue``L zjGaX2R%yM&T2L#4Dz<$af$sQgRr2vC<6uyRFJp1LlAbT+?YfY8bqGMafh^G7U|I;p zZ>uZl6Z@-lT4;LAB`;||b(y`c31;;)e_X5+TSn74ElftbjPgP1AAk_3%0xw}PGg17 zdEkOq#v7OcVw|dLc?n^icpJ(DKFS_|ssCxO4bg}Q7?)Ly;1i!AWz<)T;PeV#L z(}obq{N+Vsjok-BIpmb=EDBI^c0alr5iDpWt5g8h=RxiPl7L&*OVFdO5g}F}z3uBX z6t0wdgK9EE$zjRG{>c3#JU5$Js+K1Dul3c@niBMQu|R?8M6d=|=e%m=Y(qynhE;Rb zS3K`qOa2D`t?|CF0_J{%zhYS6=HKc^D{GP|8>cQu~H&k?6Z4qsmdAhV~=+sOSZ-~#HUs!>Mx)|%v1*R z;o`<_;cT?_A&}0G&t1Iwd_|zd=i%87x);@5&OSPI_>2G#LA)MI$vAgIm;Mm{n$uTT zMH)n+g@#nt7P=(lL1c17qUCv+#{{Dh0aK3B@<`C#;Q&-^6sI@2@;@|5%IJ?}M^Kv1 zT?&eRMFF(ONN)E2Rc2keKMoUGI?N@PAF&%L5D9j!c|9CX6=V=6q&Pv_cJef{i}xkl zZl(3iJWWWQ&T_ES_+;fq(J2!~^B z@9v8qaD0AzkSxnIP@{J?SmJqeRI0GIO5b&6$xV7sn$@C8W{GrYo*nU5qohef;Qh3Z z)MqiDYnKCyY6Pvs<~Wx;WBu|v8@znvsa|knU3IFtb%>#^{(IY(jGGd+ZKxj&mAA`6 zc8LX=B~q8me^xOCZk!#hu6yORbsjR=idjkYSKNYdH;9v$Nn#A?u=AVi(G?$R7Z(=> zHAkxiH1-p{bLnIt7{NeQs<5J%k>i99bT@UtZ&j40xj#sY-3icPt3zK7 z$1x}00Q2aHi^tm}9HNxntkB~x5=pZP3Qx4CE(|^}T|tc81m~&f@_W%(7N86lB+44v z%?jk!c4=7qRJh34$B%JvE-u-p+b0OdN zqTLI;7N}|ntA-ahp4Lbuf9eONZ;3kXF7y&j^sVGfp3dz1*v}$w(56#=tdXg;I7^Z@ zlpq4C^IhSEQQBtiBijs1fBnSwPAL1@L}>m&p?92#WwY84+-h#xstV&WjF;<8>z>Nl 
zdOHUjS}1ng&lp?2)F0dCnx+auHRB?24c-9}h<|_$L-CD;lKU7daw5Lj+L1lfJ*L2z zqWy$QPZ@UiwGeG8i7#!=wT}~AirQj{wuZyuf0X}#*QR2@EdG28vrbA+X%0Mz zp~iW}(NNd$*>F!zULWSAhryz6$A@;NpCHRTrio*lO5|HJ^7$PkNR@iJ>XwQSp>)G# z^#g-T!Om{C;j5FC4$8Lw04;ai%oh{)$Gp2o2=k`am3OO4vOriEULTeE=@eU`bJih15`H!I!m+%VzcNq6`wDKt1=mCXwd~8#WrE|v8YNx-Y zhjs=)0?*uNWhX2@jNH@5ElkjU--|3(X{$&WtIf(5%~m`4k<+AS$!s`i8}gJ*qlg7# z>m+S$if_YL#)YbMd8Wc=VecgA1X?< z^85Y9$_p$IS*l5R<^X9PU>vQU3nPpI0s;}jV)azN&D=~ycuNac6fy-fHY=B(dfX-xv)-#LELC%MXPrr)`cXj^I)th|2Iyt6S9?7<+_@I z)zYw^HVXk9e#jNE)P22v3JH9y!+2vBOts)&mr0EPbv+bWOB9a?eKgJU-yt-d@ep>vIJeOz7C_QK@&Fb8VPRrUYe@|l=D1UmJO3;M)C;bpHfb6 z(^1cyloxV;r>m3{IA^Sk;2kRUOeQXXjk2?8=KcA=@DD%&g?i#TRB)|B6p)fH_R;Q) zUJ^xOCsqfB2Z|G;3%3So4HY z9=bsT1+|0e0=>qYoCAX6W6^7z=$hIo{xS3BB9G-@h7x@&vD{ldf`z*t$FN(fH+jtg#`QU%|Ah$`2Tl;;Eg%q ztS!p7@qS`acsr&(>*6InuSpygKuulU>m^j;cC#?&#Dk0fwQ0$u$!mFOp`x*IY9J!m zUc#+}Gm+I91THl5sv4aLp<3})qaX6V(%(|t0v-UpW9KzxmE)y(GBgo`%SkRyWaCDz z2co$6(JQ&4^JMuWk|B8#*=3MUP8!wNr;j=p8XV(v9Ci-bbhdh<7_P3Old2WK2Vti1 z66}UFkzq$)eTu1C()BJAf3J3tcaqiV!403OcgDiIN}ssDWeZm4JwOh1??Q^VI4Am}~#S)xfais&SDyuC-L{Q;OZFC>!oC&aQ@ z6isLCaYNFZqB>^BItA!XN-gYXYx|jQn;?!{oya!#Es!*=(lw90oL;SiXyzo<)_-<; z`6A;WKTNN7rVR2xs$sP}N4p?6#=hmm{Y+5Y>V1~lT>WcE_9PT54@sL}7*AFx;)HIT zv3ydvjaq774qO{$O5}adJj`eVX`NW~rqc)hLN?oH!YcaVp61fVY}Lp_-L(3KdX;JKp=&Zz z4fo@j#Bg0@&bC^M77+mGZ^436i6=UbeJIw ziq#*0@m_9GQdE;xvw-NDYWkm|i<9tcq529=t5PXD&1Z$hDSLKSA5DhE#B3%lH61#F z%!riC6@Q#OM>vW&_qKI?CZBU^U$05JvJ5U2cfT?y?R7F7^IAn4A;#R~-F)>r#!-3g^<#-&u`y#$jUSVZ{=BvWtLJW34-^p6f0c1x>8i3N}7$FdKGV zk4dw;U^Oox>?G;_EH_Zy(R?glP=T7$)I`|v2s_VeOwyqkWDJ>X$v>(m#-gm$Eyj=I z=;GSGx;H-g$eD0$#kV!5N?dw4Aj|&P02RtnV$F>BGXo}0OsozRBpg@u#)LT!s!f;T z*Z^4}mm>9Ffmk_>K4qf~dC~Qw6TCl51lo2|tL|4E_QDR^BCN6LV?&n(TL|hy2i}bA zF$)TIzzCLD^*fOKT{DozG`99}3Z&l}QvUrM=zi)-D%P}81xxe*L z01vgKE6MJStcA#@D{g;G4@>axPCK*lqmw8Fk&`l|%;hdo`%^mkz1cN4;=-Od2vu`| zb$luJ?gor$b`(@LbVN6_QQx$`#Lg0z1zK3L=M2n{GT;Q$ng2L@a4SW3;c{f+lw^v4?F_8WbW zf+a%v6rXd~!;1C9%vYpSr4{A}XuWX`d_iTd%qTZLUaxaEVK)cbQVX}2u@wM1l2uYhrh64b* 
znAu>zn>FOj1XfH_%(_id{k>nh9yD~&5kMJqip^2RirW17dprJ0;i#p2F8L%t{6wSL zitRUy`w@>Mpkv-po=R+MICIg&S!HsL&)!#C!$M&M(+H@~CXW+!CEnSJYnSm{hTs>b zPboo577j``)B){I+cc0oGSJ)a~gGF~=W*m|!B>Q0H? zVYI!>s*RM5^e6ePQ)U%Xze@k?9kpi0Z}+xhG@*tpd@t?;6v|k3A)P>I!>u=?M}Ykc z4cMP?dpB)b(i~o9$j@~WXYbL#Jk;vc)BG^;5I-*bD_-)z(U0%-!6Bq~U!; zU%bp1Fb0BwuFx;mRxTH-UX}{q`IT*sw6jwTDh5d_inG^mpWK=a-g2*{w=Fx7rOzm~ zA>*4NgNqEQ+6-j3RGTz3Kn5r`IcQUU6pha?(iwaf%?|5GDb-Dh9UtIA23)}UtVWh7 znD~d;GoBlMj{GsR$Eu7`yo9XD_z(Q}0-cv9{s(Byc%#}RYf$@|T)@};?*mC7X21$c z;Z&xJPcThU__y8S5`^cDQy&oy{B8drbyD)+O7mVv+3*#&eA+HD`n)U2%Sj z6R^pqnunXnl_=)Vv02K~3;guS-yvOY#JcmCIfg1USvm?&S4tL$8%Q^-xg&t^$}tg9_k z*s)Qv^EsGeQm$oakKE)EVb};Uy>Wh>A&#zt_CCZ#(InRhjvmAuN4rH1;f`A>6LPso z&Uqh@8p4w`Bgh*l=+}v+V-nz{n*-2p(!K+3QqbUJK1+{O6Mic_!CCflpg)$7JtMOj zfY$@}AK+C-TEBie)ySWxAEnfPHa{KK9q3@0*5JI2=F-?UWjxUkqUy%M@7tqj3-HiW zvlEI8u$FboH95cXsBhid!xD(zX&Va&k*t#?! z#c!wn?6Z8~Eq26K{4a|eJ3|yLB_AYvFCxTqbJZssc|bCO<`ya?XnSM}QvY*RNV{=v zKbCIT>;zE=|3~jHxRqO&wPibI?mrgway%;gzxFXk8?zxZx<_5Q|9)-*;IvU^(!cSz6I*)JzvY78bcvHrAijMfeDVjU`&Mpb zlttsfvuAi1SGn+vlRT7>x{t6{Qt*@|qc>jsr~88K(cr=Q>hn1@oHT>Q9rwkvewl-N z^2?|^!pqZYuWL=LqCD9mP{O^elL7!ow08C3Zb@WjNl7U?3qd=&K?_w!&1Nz@_nCh?!Y+G~ zN2|NZ5526sLO`Mq6FvTlAyM>}KWh(bw|k?RT1I5}MO}ps z^}SUKEh=0R01TGI9Wv+6R&_x^fjj3NGUr9!@EO)f#ZaK3)kpMp4mgA>k4$zRMJkyQ zv|UK>`QVFbzvMl;r$Dl=eD1F`LpgGj19`ijf2CT*VTe?6vy+F?oO^uz(tl+qbz?df z<0a)h`451SRQ&M|P#Flyuc^aj=i*_#9885Mg{2Q{s9wpkjqk15RL%Bm$;nb-Wh$kP zbI*B1WKLXQ6B_!}d33*7>#WY~%|71FCreGS;yFc8wYlTx!k#SFG1jt8q?Fl=Q}yTl zZETGIay^@m(hGT5RXD$KG4BqX6>!c+YcNagp3xrshiLB653f7JB^l2A2k6tF0sn|r z>wLBz`D@@i4}hQUEnUGHnPGc~dUJIX<_=|-$dDlcmcNB3JTRWM!U7w>(@t|>A7OVf(P$jm> z1f+VJhxS}8Qfmu6FxcZ)cmFMVm?2uz^?s9@fwXhKS-!;^Faza|wE7?s$JEwpVI^X? 
z!>nHuL3oY>BSp{c!eQf1gygNir^StP@LS{M{SP~)XBdv0NUnMg?ooHLI8m5Vkhy=9c1)nzk*k(!PRFyu=Bg!@9{tPk>5rsSlx^4F zhI_{+9XCe=?EvAnaXbW4-Et#G7pV`x)g=Ncls;Eauu#ItS)f%3(JNTug%3Cb6;|Vd zL)i*k!=bZ_cIXV)Ib%p%&s@VF{vVJE;~}T%h3+Nxo3aAcqeF@4PjiW@m{sP>w9U2kL;c%B7xTFdF*j_Z%)RA*~#cZqtqks$WnaPhlyek>=0#mTV1)C8QxP;ATHza%+p^H)Bf?0! zYSAxqK9`oGY$Kx@TC-p70Mt9v>F(G3abnySz3J|ww|*;wXu~V^>-E*Z-iA+kYpy-| z6Jio>9UyA?`aL2dZ9YxTC;(nZh9mAy=ps(dR7dOVG8R|?~5gw{kO~TJuv4m|1xI! z6sMKxOs~3R5DFBA*c$kH)7O>cSl!qWWt^6uFR{=klVyI@Sy_W5^|Q5IPM~~=tG=zl zmFb1pA`Jn;z9ayNewWqp`LkRYimnwXHnwk;(54QVMO;aC?Ds3U`PD@M=q=glb&i{e z$VWbIw|oLynk_zd2QlBG93-dH$1L-6sTxZ1W@KOg0Vsvf+5n+!~s#!KG_xDks?4 zA>_n1>Z4tA`lEQ=nd?{9YgrD?j@g6f*WuE6?*NAogidyrpu_(0%gWNj@^Nle{ zW?!f?*l%VemkvU}K56<|xh1c~p;kvbM|L|hf}hnIFVQp)g*za& zk}!dPW5gAOhj|Oz`i6Cll*l;`+tPDF5w1Vlog!77T_6mfZ>(^+D27Ld|HwgDd!Pps zqb8)JxSCv@=Bdq9mR4S=c%20RedW@R^3IjcN40WRY1ErsM;~rSA88^Gg?Gr+P8e2P zjT6|x09rNGs5ABr<7hSRifv^n6a140#h!d9F~2yCChMWNdMTxBoA{>IQu9~w^1T86?9Cc^1-?hb7|Pi z8%^m`r+FgiL5Ymqe^zaA3T7+jHQVB~mU@haMDJ$8TcxK|x|z!bn-e?woonvLyJzRk zt(S5}5#UZ7e~ON`V}q_;y>I`NguMAuuop?y`ZrnO@z01~Lw^)c7m}d$QfG$H*bqvh z%mX=m>4WhC2~M7=)lL1>)tr=^L7$S>JS%)V5~Myv^CZ0DdGDTKGiK-IfJJY6=D>+P-ye0>q_xWP6U;|n-m%zo=a+{tMV zX$I`zef)dib;omDQ3_^!KlzAOg8z_u81T8ZFnoJCi!XHlW`(8=ej$O&Ohk{VyY``3 zpBA!gu6hX#dG&*Mcw&$ANxnDV6HZK425#W&_xm<~v6+>u@MTVew5hnqp#bXNc=&n9JhKm=zs2 zT7|A{fqkl_<_xV-BamQirbvPiuy5F$Ow9IUHl9dr6U#iNc{M&#!A@on4Yn8<&?p0W z!l`!xl=flA-tKbkgsJuEx|i0u^1B8W`xn)+G0WjGE zn%6T1Z`(1N7oRbVB$xj=+gz&K1s}Hbw!5x$I$!*h-s?=Fmlg1Bssbg{8LOHKRCVDp#2IWeOG4A=Jf3 zXzQUBnDwX{z8be_Sg}_#v*Ip5?tdTDSg=7e z*fb3`#4k7C+lco|ab)#^%#h25>JSpUj~HdqA-E*+f>)V>W}1Oq#;|B?Z;-yeHCTC= z3N;zjhzNu+DlrceR*blfoUD0*=SietgV>3z5jd4NCp@-VY6L0muoHct`xU^2qgjYr zJjM~YO2?|qWPG!Kv;8XUCNGV$!N`7DdmF`hJ4 ze}%BBG1_?QLcm;xvXZN|mN@bRP!hGqH7n}}GY2+4CSo+OlMrw|U4RLk?6VT|<6A{%%35E{K1Rb;6Ft)1M3O;{am~bMek}Y@)fW|nF)BzBOMsis|3P|1kr^K>^ zV6|cv6OeNRAc%#s+D284fLJ2R$d)KDC{yh-1B^3dSao7X6XO|+#5M@_0}vL70FL+r zEfjjv374G81ah&4#Hb3XfK*ExijW3#!dI7mB?J45YM*IZ`=7mB78qo2O9rI~VVz0m 
zvFu?Tgb`qaiL_HQ*=sEgD-)-!A)zH=FqU~S3+0!NC)x~QN`tg4%@(c!$sA1ZxC3l<;bA!{nv?Jw zD)`s$2Dmth#B#cmh8(~Q!4oq)e&G4gOzRV1yt8c#VgZ z!jF_Xq2<^^%pS3Qz$-pK6AyKYjKp+~XYsKrL`kd`7#*H9!58Lg3~IzosWw%pW@w`0 z4%jB0S6iUw39QyGs&LN*M<5YBj2X*1#K1fmnPr5*^yUoVKx4KeVO3^C#Kd8U;uhf_ zg>1qn0%d2+8PDy0+yt?>Mr>Dv6{%xj&D`moCyT-aJpz%L9xtvP!an>|_i znT58BVEh^wM6R}T;tI=MN?PC`@vsYH2dZu>K2q^hBgG97t6ZwoV5w(`bt6^W>*GRG ziR~0|xq~H&0zIML#pWGWEZR3A zd4m&zIfx3)*@*`sxkb#^5vDOSI0dp!ax8OYnfK}svK#>025f~g->*=cA{@YVJgg9R z{j|nnbjuZRp;Sfi<*BLCMLTzqhpS0I2L3&m|p=OS0h$&n1TWMc}+0&PqjB#%r_Fp zKtsAgb(R+uhyc}Ctg*tVxWkq#B7Jqp#wCnI=Q~E4Vh}T#jLl+dpmPh5UD<%FL5w+E z1DF|sBh1YM}In`C{u9XP9LYRtpZ-It_xvsPKuF=P;m~oo`#X~XFq}UZk zm8oKVR3BcPEV=!T_Vi6J+$>IeQ0&^4>MNlI&WiWcpoJ@?N zLmAgCWmw6c^Vea)!9omBVlt~QCpE$Y!RUH}fEp{teY;d;uNC=uu!!87C)y?iFNs>Q zBLHn5Y3)8BAk9gfcE)2bIGXy{oNBu&qeekAg`&4Qt}7WsltxC4;7a7_B zVAW@k=i3co#@rTx*nl87n-eePzcrxEs_1Dn<)Oa8ikt} z=T;{n6^sH6H}a{7a~R?ik^vC0m|!a*kHp?s?SRm#w6d@wt|I!Ia!o@le%7uWh}@8{ zh#Q)DpI0_SZJ@}2Mo|=C<1G6PcuwSSi%~3Wt&BB zWnG|02T`G(*Y=d(rhb@lpfIY5u>=0Q>MHUyVg$sxzarH@aBFUz($Hm?kL?z@ zoD$iNc2eO~dqEI}S6yG$#u*aKz*Pw{w8^s(m30*`?Ee4=G8^SknV=9??{L(?_P3UH z*nWfJc?33zRir)SJ8cxRG7eNj@ykz$a|Uua8REt=rwrQ+q7xV4u|6l0lsRCtS*(60 zdw3D0T{ao1pHZ$#o;{LRVl>beDuf!iEZF_BL!GrCrFjQgfdWQxNK$OE{;It43C;r@ z2C%Oi9Pg1|DiFrhOvNC}a6*t3R2FRwer;oDMLbuWCybb}7ZVwBO!m!w)2jPc34n8I z=AIcOVJ+RUghO+`1e%NtnCu#5oT1h>Og|uIYV3cFJZK{hQMNf*ra@mtO=8-y?O9kmbue(b z)n9g*oRWP)uEqUg49^}O!-GhGfdes!0GKm`n3q|qo68Tps;H>(ivWt|RntRWvd$qe zMHBSmG33B;Gwi1)sisw4iaCKXAB6-?nw25%26Y*#+}KL>4YGAA!EOyWnXEv{0}YE+ zmBX3jf%PlIA+!GgWAWkhD1Mq)w&d>X)Z!6H8pQO{<|5`pqvPW$saVCoY$6422gp zZ23j=WQyk?IT)+Rs07no20RsX91q25N1OfmB0tWyA z00I#qF%m&h0}vu%aWa9Cp%WlNQev^eBQQX6(c$m}GlG&(@g!hlqOufILqu?MlfvTv z+5iXv0RRR+0{;Mqr~d$@wf_Lhl~uQ?UD5vllaBe_WnIpp^6lo7jCB?(H!@502~-_){!yPv6aH$nTFDxj;n6r9jHsxFE&Q@XqD z{{SBaNL5pP)!i{o(kQVJDkHyj6%|)4(7A7_o%}ub6f_;gaG+#{Oc4m`@?ZNK4bPg3fN zCQy?jnkVG6E9g)zoIWs)5ip{G-!)OgWN1RI#*3P!aq=Bpf0|!9&tE8&;~; z8)L~e!qbw-rikWB=2ohvvb4y`e4HW`GszJ4aYu^c)nW7Iuz+M-rzwb9tSItC{{SU? 
zl-ZY9iG@!UivykU71=7H>r?pZ7**e0)e4T?#Hlz`LXfH#5aH^vU8?|w@|axw3jYB7 zsdI3hZ7L5FIoy7$hl!v3e`M$sm{bVpcXL=ut&!*f=Br&*OyaRki4b7UP?+sjc&SZM z47N++RTZirnFUmvT@%q_J=&_ef+bs>QQ8w*^+3(DMnbOR$xWEmBvT-)tGeDRwpAr( z)Rj-hT~$+kB~&Xag+i*ho*`?}i=vj)DxfN((c|h+?KJ*n1?)Vr)pom{xNx+L*V?ts z_{LwTQ|(1Ga)BaK_0v@?9YT_*p>=d8aM@ivdqn}t5XC$KoYiicOMwU*TN06~0ha2Y zg2I5H(Ir(9Dise2^4*5xBx7_PVip*NsVa5JM`dv=yZ2Ij7cQTUpx;GTE32phK`Mv2 zMHGs)R2bE(jot`0RYIZKbZt~@-L|HxFazMbI7*=eg~H+ox>T@5W;v)e1v<0M8BHxb zYNDjTPU?pw=Vx^6E$&U3HpGBoi_e-eYnTbLE#`$|6;b4=ZQg5z*`N=$*hbj2LqdEJ zY|v}Hpqe4hr3U{1_Q#rwlAv8lJF4|}_%7biqar)iBk^?gY^nE-wNy0bs@{kUk&A8)Z<^wEa-K5- z#bQLFP2EKIy7mgUHFD4{89fx?!a3a;Yv0bw8?uf|j*5@+Q}Xf&jL>pap7mNriyqCcT>OK||)%HSY@l#o<^^s&!w=uqc>N(PLo zj6g@iq0c?-RCnmTBV!euTAMUEqR)a*11cvDNVH@ic%lZ22?l291xuJg^ff}4vSqlC zn5(i*V@boGUCwtTBBfx2V0b7b%p2;bEmqOOaR9Iy;V_1osa5EUyO%Y81&|$7ad#+a zXq?K1;0`&3IHZ-rE16Q|U&dS1;ZkaJ3SDg!2;FI(OA#ZYh+n!Hb6hg+_?2CeYYBJq<#so55@CR_WDDmf`)ANZ6A?a$xa+6P!xw^SgDLByFe4j<>a=H z7mw9A>4U`=UaIq=n$5sMDGj0}loy28(v~W_nVzM3Ov&MDFzFfL6@a05YiPx*IB!gjEJJH7;PBI2t*^x8`zdn9v{@B|;d0(i9Ap0Wnk^ig8sA z5L!nUB+}})6%>MJLZy~y5{N~hc8siDqJ|nCL2(y6-Q1TJ%jmH`gnU_`a)pS<3q~P4 z1LmQD98V=*@ihvRU9H?I1eHLl+VfIs?*OZfMq&iKMVUeQHoQf_NQZ5l=&DRyflU=IB zv9!1!X{Lcr+xl`GJ+{$y+UCtcTs6aENge14N1DSNiR7apx-(8LvW*bJWv@T>qz3@r z&_p%V%23542=YvOz1SmFRtzy9_@gu=L=97kt3oa#h7~FFMU9Wmq9wJ+CbhV+YMa>s zww+Ozo@$NZw34vm9MpSS1PBV^jKgHg>f0=xOQwr;jXx4aR+sB zoM)h-s*ci{4&|TXE~y-{`Pvtg4hiEeX6^mxgq30-YGC2kPx}jH7QlA-B4W_hJ z%%-$y;DkiXsiKomjYzHZD-|zv;BTg@g~ga_D7I^2n$k!`#L08G77WS6=BnErQQ?TP zoZbA5(4)~B`$y)3L8(|8Z=#1wpOXdQRn6ISGVc_95Qlr%WcaQpZ4XkRxRuvpdKCe` z>{W!QcF1_G7YeCAhqBXrD3!z%(Lh&RMvA8nCH8QkWz~Bi^<9=8s>A$UD!V!4ZmX7I zNv!Cw;{~IzbXA+&2?@MS*}1)v9W&gxwvFlngCm-(6R-U&-OUPkl~Ghqp$3g54C;*R z&E|@)XIx9zklTl;^kv-GW;_+=xnUw7rm3`@qYM?pO5aYY^r&)*&C^5x%(Y$SVMB@o#cgbvlw@NC6Tdh~l1dNydTTJM^# zXHu(=tSXNk_7aG`p>{i^e{K{IRR`j0zN(d#7-exmQ-CT(Gd2o*YjsZmxCe@#q$^D5 zREEE(Dvo=4nU&aS8rfn5wDVHeH2(mQD+{o}M)-6pn+I2+LRb`G^IG|?AOT!?t4#Am 
z=PT_F7Uii^v)TUuwN>0*7Fdjbt=0pHryy*cHjdi8%E`hj`2;t14w;BFe{P$ zP;;fEQU3s{2WeKbv@;;Bt4D7Xf{R`wkW77N+E|@u4i6QGrg_4Hdafw)+Jr$^^p&A)kfz3f`QjKTuu-z86$_Y^4$v)Z21*aS)ZB?AoSHzF_#k@kb?=w z7v&nL#f`$JK!JOw4}t-+MZ!e+s4eKYNgtZZ4K&ek+$)Sqw@~4W=Bmz)l~Wy7XQfev z`~Dics;D<{Pem}Tyi7VP1 z%TzV5Jkc0uTCNN|%ZO1>KvTm_3NfaNlupBlN;7|XaNAN54L%4|4yuMnDOM=`8llY* z)fn0iFJ%WbaJr6It5gYSb|GZwZwU}UL`3H{g+M0iAcnU+N*wUDPh>d5Q?VYX`_-_z z)foet4Ub52(E|*Lq-3izX;@}+s>D$l>1Bnt3V`^oH(@%JhiogXQC0}#s$_JiQX${s zs@2*2C6$#`4((LYXrgJlgT+goR5a$Hkt?v+Ih@o6g$E4s6kOJS?vDr6RipeusC@Dv zR}nf%U52;wyXd%_z=mMA1kh-+Wq3&E+YO#7l0USJB)kU*c zqgQCD#7NRn;9((qJtB(>+5<`z233yfET=`v_JCglPlP#M+B?CJz}EJC9ccP6!l z>hRW!>b@P%lTWI;+4X2q?5PHeG(M^!VSi^XO7zunUsc)*FBRE8g|e!Fx~hr-sH&AN zqMNP{_AB0Spk0disq$5_paHzq3jpaV>?e~}UJH$--&g!aeoYfW+P=OX$+mhsn z4#^4x%oUL--j0;0O>Sgj?^w>MyCbQU2E}&U3IJGC<+G{kxH~|2g?Ay@`AVr=E?9Pr zk@7|hoAC<5k4nSLP^>46Xsz6=e-Tt}RE2e-;qh6m7ANGsA9O2l=ncsMMD>*}X3uBM%4fY}40hJH+B=8CH6gf9{+nz2E;3x>>@vP3bN5C{2M;o3 zO1N$ejYsHKJug-guFLvaCMBIQ7@R6wj% zA9iq+=T5~|RS7GJ7|Nc^-? 
zUj9I4sQBaplKDK`qtZtvpV?OzafS9D)3p3gGO_KAGH%a?d(;xpPr8YzXx4#7NS6#> zC6EdOX~|kU&@x+k|TqRH#*Kl&UH1J1MKy898 zk*O(-ahRbR1%%X_ELdss&23}ap5$a-&Z=-O!0Nu$&W!@YvjM%!4#C)+@L53Z1qAT; zDZzZ=rBh{u49QVasRW_ytJleL@l#1xpTgUIs-am~N~(iEn#R~HZmf1jN_Lw{a-!`u zqs3MAAh`0X?6JR^7xv4WsF+dc@p6@BgYbm~#<$F`v)PWWs-2^{ zE-0RgTsT6IBa-69^ifD$K=>}fDyP9wvU!z58iX^|(zu#>ud~`j?4_8^57IzG0k3A?f6S76~$F;^<9#vuDEhuWmO0~wMO0Rh`>-U z$*b))hBQhk~Vdl1&o9xI!@EhWQ*H!kC!5B#it!wg4(MXoN+HchpQlK_9D;Cw~ z=1Pgf9`d8S6&L8SGFW~R0>o>l8vfb&ssgNfDuqwOTXzszZtZU3Fpd?(U}>r$qNeh+ z87K{y)k9)E6FRw+~u7CxX`%0M#2a=>ta;EBJ4n=VSIfXMtuW|&m z56XvbIj+xw;XXH0Gt?qWJJ_O3d`k^`uCU`FP7_Zx+5ME4!LlwqT=i26hFa*DgkqKA zBqCy;HHT)t~0dp@G={0QRHX8d+D1;V8>jXNJ{XtO2?jqxPg14<(u5fTR;H z_k!j!RK!|t2B`X4%W9u$3MhVui;@tCH$rV0sSW_;D8S)R=GUs?qLx>hpmrWc7iFWp zKy;~8D#91`Sb7Cbm;TjO{vMmExP3}0qsd9`4PBTeRo)A=115zHY-M$cDm1WTFTr?lMVVs6xI!L-$I;KYQ$KGhX8f}l`lr%My6B#Y(A zaEj`xz@i8`Co}>b%|nC~+)fG|IDO()CU#@pSlxv6Q(uxH?*pThmj`nzfb4wMYGn8> z7OuydoI6hCIHRnqtNtRYgOX)TuQ544|3u8s=R-Zh1&WBKH6rs)mAftg2$fgqBS9wO40y zBiT@DZXs

f|&~IAc}CR|OMw56`mw3YWx@Wa&4x5npMG=Ao>vri+gU%~8j?ibl&3 zTw4G-QFdM4xK7Y=D)d<^yS@sts|u&#>${Sj`UT4%jNL;->P~Br7CLR zsl)j~lJ4nIiURHU{()kCAl+Pj7*NfV9(Bvj2I8l>#?eR>=&I$!N7YFmRhVx@Cx^_c zFrM}Omi2L8-Pn)HS02&6Awyps0^zUbD~ZSb3eGjcR7gD3GrkboYm}}j+D>q>?BU#{ z9>9)E6DFCa?L}jfh6#>HVZ$)6xC0^r;w^7-^nx~kh{Bj%bP4k-fFA`rMvnDXH@#SQ zDj4xx1fGsjwQeqZ>VMTB%aviLh3`kkxjeYO63k?36dX z*_An-^lMz_C4&(#|n;BbEQ)siqn}2 zF}dH7tBwa#i{!XA4?p7pRP5M(Cl8{RZCyN=SN{N*cYs}oH}h>&(lx1HZT(0;3n;iY z8&l3H=C~Th6XJ}z{^(c>aW|sFhX~A^(ZVq`Lxu=)h^HkE2ptp)l^LOmyp=^2peimV zJHc^mp!F&PN$B-oP2{~B{{X3Lv}WIB^-|C3b{jLNq>!;}!(Kh+JShIJ`g}vXs5`i5xSuPv;-XK?DzT@Y%&2D93%T&9+cmpE@v$FE zunl*@;?+Anw%27jJ-C7JSih#)$Qd!%n!exqeUUtd5~$syN4q$J;Mzwg(NVMcOsdoI z7F^Z$U2;Q20KYYc<1iK&kEurisW(^-nIKRk6655xgsI`Tg25GVou{e@WtFpKmi{i;w zYySZ8-^G97{{W^9=I`Bozx7P%GvpLw+8@f$)C<}+U!0dX_b6u1{f$&$j)D{j{{SD= zW_lG3aU7bUHhbi*%)2$j$URqTQ-ospvr^Yb79L9pcMSHEHm;I+s;&tNu{ssPJXAQm zDPbF3o?}Ie`g4W;;Z+gzRCjZ#3V_nhLw`cMEqCO<`jmYUBoB8k>!}91g@I z2i|0v=GN1`|jLBixGZwfl-c~KX*dK>yUp4;#huODd@m=04(rp}E?|7BQ zLq?geu-M7>v{-#6H|91CpEX}avpoLM-?HKxO{eE{c?H3TW;4LmiR7NB;eZ9)Y}8${ zhL8UC(w+;_8E`-TCewr->y$Qs(=)#h$Yj3V^6zvXnOuAf(^Z3D!5q!k6X%c=3lTR& z&~i{KD)qWCdnYqfiKPr0piSWpUP|h|B|xFmRUlV3uSv5f>VVpTizjNh_?fB8y<*t$ z4BEJ=Ji@I`M=-HJiaoDrDcTHU-xPOxotgB5mB(_eLc;DQ^ZF@mN0?FSSMXNdeb!OB zVMNtoYiX|~@gF5F_N;(zNv7^1R~Hd=f)%BED~lETD43?RbdWIKSU4pTS8mx?1hxN{{SV6%|^kN&-Y%FXj`=S++$`wu2>ol*`waz*^lvEsAf%Q zf#9j+soE$vA?8;YuecA)_;%_fS(CvPGZMgC5a8Ql^e` zeXN>!Cuwp@cbej(Q~_NyTwUnaNl>||lXY^3!j0grE!Z^d{)+?qj;c*sZ_oE!jgTaf z&rj4SY;50fJ(qM*wrOMi@8#DULx-~Vp$4H#Xhp)M2w!Pa?+dBBbqYeI^!zn9YOTIX z(5?7@DsNiS1uxuDG^I9m+O&0*Ihb6#i zp40SFcy?3bex*R8bk%*J;`cP_gZA@>*kHNnGFU`64NO^|sX?tAHUWAzR<)05-?AGW zooiF8l|Q!H#;8`NONyLI$bhC1=|zd7s6@C{e7DL3F26KGs(qpF57l5Nh1F}3Ywco- zHiB9Fvcb0FVeG$DaPe&Yrj_Y771f!kJLD>GpCwrZVi(Ww)!kk26;&v(M^mD;5XFI9q0^JbOc6y(Kog zZE|l#_KuwbFla928#`FMsBd{I`9xsDM{VVLHt_6QeO!aKv~RM_P#NelTpUk1jJhv> zYo3k$N}(=!x%w|c-|{+@kK1(_$ZClf91Q;e`xN)`!VntNpf`Mx!WqkKHTWH`L_@o5qqksvSZ+R{G=n=cVlwIaAa-O2157qOKSfc4YpzaMTn|rD4M#OQ 
zGd~GjzYA4Bq?&q_RZ!ad&ecgZK1z)Ovo2RIEy!xV$Og%QsVUO$g_HAGTi6s$XR~zK zrmGxR*$8JBDcEjoqh|qv?kmykXAiXh0BPxFvvY};mp zJuO{)I(1md>f(zwR?MHOE!lL%nf;UtXU$_BquBbDUdXGDC5||Eck=uCD*BpkqRDbppt*&I6q&1&qC}V-7BTFK9%~#>T2*Wc1KzrGSyxyl z#q1W)KzwP2h@x8~)KZ8$U+dH!aDozbly zAf&Z+Ovq zJ*~3%R6Dqg{aMvfWUN02C0lh>))rG&A$ZjjTO#0&KGlF)W)oSd;*?l3tRn2PK=+j{ zzApqa;v|!jYZ)VFKsoa%!0Rizs}6amoRhYm?HSkk^j?!|Mi)?K{!4^NqZ-!z6gInC z7Y}FQ3yK}IXrpp(GKq8hBZ=hz!41f2pepq!d)r$70Ez(iM3?kVbFNCmlPTVu?P<6p z)fiMmab4@A6)C4sj)kad1q$+AQJ(Zc@kQRyK9RkBq4iF~td50Y=AmHrhiCLsJT+Ld zSRGa!a$LS%q<#*{zNK%vN|RN`QtFVV!1h2^qjE&L;89f8$@s02MeXfJa6zVq5e9%Y z3_SvN&9)ahuJmIx>7LA>4s=9kRf+THo2RsN3C<0-4$ZClFGaKB+u7t~ebAa3T6V?Q zar>PXjhs#=L>kHq!g(Id$8``HnkKTr=BJW#{04~pZh_E#>Cg9gzeE>w?(*JuIStXM zg(;w*HK9!!Bh(bZst4XiZrUeHHyeXk!8LUFHMDIPu>F|#^PW*|plc9P?7{l8ug|r+j zXviV4*@^z8*CW{EIis`L5$O+aXg)z~N5oNUW(V`?wns|%D7xq;>Yn47^kDM2{&PpF zV?EaS-8y}zc@O(Tq9Fy>e{>*q3Uyr&<{?VvoCVigC!VDNs}Mg5ol}o`F3S`ih#cRu z!hbV=={aH>t7Af--*k(F9NJHktq!WciKgl+$`V4dbMV|CU;h9}=$y?*B@P{y!le#M zEfrT+qm#L*_)~SVz}vbL_DufVQ{v{Y_OB8?58XMf{_DF(-5Md$Y}za;ww*%SMTNR1 zbs(=Lv)TKNgZZw1oR-XU!g@+`^>HCLNS_o3Eyh)Bh$wSiP#I72=C+N|%1?T5_ezkV z6QVQOQ>a@sN+|MOMOMlF=?CfbPH#}7byhCz_;va2gJnwZpIl;A`>H5dM(S_jD8v$V z$wVHel4G+zEAqlP* z7fx{N@>@-0ELmsbt76rMl4n_6(w)hnH@#0JM^w^?IjY{t-9LZTPU>zxr0ZSvs<*64 z#N;FIXgAeB3o|Q(c_BHKF+is7&1*hr2Edw?3^X=Gdm@Xt2=D!qj(9~o)X9skE4tok zu05!yRkhkat^WYzi9M&9pyt`40y$Et#yCx31IYyM%8y#Dqy69-K4;{g(QOaFeu_CM z_fvVH87b3qo2a6Z?2wL7;7~fzbqV@M)echB)0`%^xea9Zr#B}GE{;SjTJ9=Q8ie1< zH%|Tq#Z_fpe(DqjK~-B0>-(oWB^<|sAmmU#6d>Jg~b zPo-gC%RXtq-wk|`X!b;Sr`Ey;TBhRhM!w&Y*$#5K0}2!^vveV&{8rPDTTYYmMQr5} zsa?{XGZs?ZmM{G zrYpKX0nI|=@KBC?WMUawaHJ1G(1#)B@T$UIa?|Bl;Z2DB)sIUf)pHAD`GPu=i8v2!h7ssn;D|o*4$)U7%t%euc`c_Q`Y1vz+9(lt zo~RLEDEN4N(Gz0OgO$pH&9b6 z!%{jb*dg$!_z?o`Mzm1uip7-G4IFBm4o(T^9U1gMSlv%Whny{>`%9vqUQzW5JEQI; zRIOB-lxr*Z`^s;o;-V^0liab=V}^a6Y8zjw#A_L9jP{Xm0xaHWzs%;RTcRu=D*GU( zzeD1w*~9f!_d)VS-s+pfP_09&@;i_-ps2Zw)S=>?ct0slnsAbwG-||dux3af 
z&n0`GqHp^h&~;ljs%KS?k(AeHpP^wWo=}E_&qW$fig5U(EF#uYG64*Vc@%uv0b{Q2 zl|<(#cqsRZ??aLwaCa0DM|t%soYrCm0HY z#cEYWO2kC2F9leilAqvHW{Y!GpQ%xsyn=5ub0Mo*eAAA7kcO8{KIl%2g%H*u5_A1i zhx0XyC?$6%H^CjBf^&6g@l90{4Z@q$nL*s}tX}a?Lz7U8izykSTUj_2Ui9P!i=vbF zr_?+3DyEumaU-VFc=Epx;)gZqkX|sE#bY`rYc`d&O24B_{a=IhBAc zGx@nAPaM^M=}Yc@NV(R!adq8A+=0;ISY+1aJ^7L#BGK5=qcsC2WX;7=D1A=(OLLa z<-Zk_xzhW7P>93+M6BH?OO6?;4#h$RjZdB zl^_C|)~q0+V})Rt)S%lnV#u5=3lYr|-{gk4f-cg?Ip4`&6OB`T8vg)16UJ?Z{a2&+ zn;tJ7jye{nmzkwqU{ z2=Q3F7xJhAMWs$9#P8W*YjJ#$gMr7{HAXg#rQkyB@S7l7*U8Px^-$6^0GpK7cHQoZ z1v^3hnIh+69rs?>-cxmy16|ohs!f!;wWf+`)f_TOvC&1!q)Gw2Pby7JaKbvDlET|+ zsiz91>^UZpa$s~qHIxEz{{Yne(@D+6V$Gw?Gz-}Ra}hk&6PQ4sWjO<<>aTJ_7;{Yn zAiC%40rzDU-|G{($GHm1m(|Adj@M@q8H5f23Z^*~!+mlJUH;Q`Nn0@_ccED$-6}!0 zQ7GdWoym40vw24d70qHhR~Krt`|9p8cvOb;Qh1;8Qv2wtDVZviM-8>j$`~Z9Z#bXi z6uYv~UTOj2b$gTwHXx$oW{K#WG+gWv9|!1%Cm}YOC^Y_4$o~KeBieEKB7QSeK^Mtt zFDsZ%bYzG+%IakwEzKL}e;3~q>K*ZPOz|to=9?4FP#vI&WqCv{JLsem==pR*gTm-F zq%@1l1(>JaD}~Kz2(`hcE~?NwsDwc)4ieH;K>Nz>o!m7!EKjL63({>Jdnv7RE8uD^ zV7kL9!k$R+P+iHbG{u+EIL{F``Yag%#Gf6RK1vU_Qkdi3oONha_kM|^nwZZQdLIbC zMZ-$M6KftyC>%3Y6ZKNV;<1hy%OnA??70nY-tfN4=F4az%tUfEz3j;4a)SG0zBkAt)2ZY-**;5ecxQ#if4m^~(xqB+zF(O(Ibx<>U z$U&P?EN{DYiXwfPm3H|3{{RuLh~ZZXPjJ;hG1Dc}UEfq<>k;ll6=wecqQv`XrQ)$< z?9GQnGsAwgTR(<~Z25maYm-QCl7GbN^i%%b6*9VMl?$IG*)h(m zgh1q@4-&#qYn#F>=I7ON(OIjs>LUG{jvyxvsx6=mUXDTdmk}$ANK{Z4wfEOl9pziS zMctw$?(of87YZ1yn#GiOswq9UWVO`uLKOc1B}!*dy3ix7msN=#RYR@NeX=~1A1v0Z zLb!*5;XjLMS<2yZE_1j>YA67xRV&&Jh)J(`;Q$9D#Eq zGEo9!mTQCoIay~#C0a(1qY~{q$M0+4qjWVyz|I#90>a#`SIJtu)}>m$$oj8@X0Wot zhz$_ZzAL#7sl0@sbhkx_SSfAvT>_=eQ5on|6$L?|`?_v{_t#IO|0kMVFd%RK?eN%@S;?P`Ns2tfZBY3x=w$@JHEU=T(Y)Q84%)Oa(0m!G8Y$d1cGXb-T*G?k@$$e|2*!gjE20 zLh8Nhd(kCRGO#ya6xpkor8A-<*>O;N(av&R7v)5YN(5Xf7W__%Y1R=WqU-^r4A6&S z=c3>wPK2&Kq3$aSTNBA#5u(eQ#Mc-WTu%Q0V!I%>o=Lc#zQ<++=Qc*s@n)j#(WnOO z#}yACu(qRFR2}=djKHG5fLfs0fuJk1{{Xx1rb?Hoi0ZDmd{;HYUV5NPn;n}O~t>W(UnLgfk#is-Qp{{X;K*1GKiI>G`Y%^Iq^&2Uuev26h# 
z4B0>q5wGZ{6p@>%pv3B@3+tRBaJY3cx^h&3s=4s0IQFfb7ADz#i|D28L3tT?twC`Jc^eE(;%ul*9@tm#rz0Gx$gFq9n|}(_$^gv7K5s~(MC!I zwL!Wsk}h>zfEW2JMyuvjYJ~T$Qo5AK1msXTsnvAz3YK4AztC#CB=ajyh)_myqYxbz zF{xJZP}uHMv@TIsH=UIBM_MidKB!`5(lym~U|XV^PfD&JPeluwp#{N6GtEGb^-`gR zR~ZYc<*J|DKQtn)-g`$d0g%8I_B^{BXthKN%zs#E~JB?jx_ za;@~yQ^jD}_^-^P$wEE}$G8B#KtaDD;-@XgstPk)tb}!x2SvjN ztx=;!inS`H2Q?Ix#CJ-r(XO3xQI5Q_Uu7$>9aY7!Zn~iHhMcOqE)nLbx>aa0R>jP! zt8sN*5SllY%Bk@x8(FzrKmw*_ET{)n13RuB)cC7G%EQNJ~( zH6rR<7Oy@?Fk3RPv;LMA-EsrPV*Ul909sZYm32jTB8ruWEN!6_D;JO4!$hnvbSY^; zsNxk=h6M*`WpJ^@aR}z6zAC_~x`!aRWnU4=SWiV5Ypg4k0b+n$h!RSmTuy?ZSLnM+ zt|ORNh*3C2*Y3Wk8}d{Gpj?#A6(Y*_M2N0ROdmA|hra5_{3{kQ&J>eEt?N_QqLmq| zu6v+BQ7e)qfAaklXPMD*79Xmjlag;m)~Fsy;V^{PjFg23y+KqoKXJt+a8Yr2u4@oA zIjC!CsZbYhxkM@eR5)Q&zN-z+>Rnn&mBjAldaV?Y7YgDssZ-N+>Zj(RXjDj0Jaj@3 z%?eB^+mbvY5IA7&BrMd1Fq?6q4cWeZ#;JQS!Pk8HmHU}gQ?xz`jexL;{Zz#l6~1Uv zxdlpIsohsfp)2UKg$)O) zGv=y%7XX#SKyvm#I;!zziF16{84cvAHP9{%k?$(e-ngY?sbk`cLqVID1;k}Ex?t0| zsMEb&02UP=aVn~di+uP6RT(Ia(9osjO`5B^PJ7(CEJ{x5stwZTZWFp(B)T~c1gvsWsP`G!|NpifXRNcjO!!o7(N|74MFjgx{&PSS{#yl4l zxhgbuN9v_;s^6lQ1Ibnu7NtXn`AW2M3$U!G)PSq%4^$N+h~%Q$z*HV3TYn`$uPc`M z&vTG!jT8&6O55mE=evbTqPmcGRZeRq87|2ORa=!ZPUf00OIo@0E2;^-;Kbmfd2mdIiD?i3%rmCRb@osmW7!njWF%rOSn~xP)(_uJ|EB z;U!keVN&X)LWN3|4x$)Us-<%q^;Q5V3zD6mY+W7_JC-esxMqPgJFvNM$mF5HQiC#gCC1zv*y`L3DRReu@$^}5* zWkI-cZmh4=tID^I>UxKoVM?4^v2UoZ?rM}n6wdful=-SEE+tAVQw@Y@ zdz3z7BD#60utt!zMi&$XPS-^eiC-loI}g0-F3*8bJ0W(IfI6VjLvp*D6q~DWS|*=@l?b;@(YO*; z6+|~$3ZYd-RuWa}uVvbl5hm(NhJ_Pt&`*NIO#-&+BBKacl`>EfDwIQLcDYglofY}6 z=|BoBCZ-ja|HJ?&5CH%J0s#X91pxs80RR910096IAu&NwVQ~2m)7@yx_wC2{F;{{Zj@(J}u3gOc`3`Vp*Zy=%k_2WxJv&3&3fu{>DnHQJiv^7d;0tOGcQ3W zpr*Yjn2d8QvS*pw(7E(qrF!Sk+_r8GeJ?XH)Ldd-BBJMxAj4&qD_(}xQ7chdjp_~P zos+yQk$@#^txmJ0Lnpymh=RD9eRh*ktKiN)O${( zy~Y;_w^)vLmgT9D#wO@CA1P+dMCXZKLx?(6OPFAQ9Uys9-MkqPg8G7ijE>ngBdOASi~njfF=I`;PWrgYcnIL zIxz}4$}=jG1>Hf+nSyo?5VdYyliq}}@<&5^gyTnei$S@;SJD+uBCsnQ%r97*sHm!# zH>u2S;^_w}&Ayj7moL)c=^SdjE8UB^6LC0KAegOiDoRBb>gjtXV~6$>`@^Z+Uh7i*~tLR)j?4eyo% 
zuK0?yY$v>~4Z-)9teMM+f$uh4BL$>dIhl>$<`VsH(DZ$G=}yDC-F`Jihw5Oz}rHjPL{{WLwL5vZj8`)t7(V=E#?qFII5XyO=EXufE7_GLL-8tN^ zQdG;iXy>fJ!#qWYoy`ojR*8T`Ic{A;9%5R}Mgk_%@hA+Bx&psK9H+S!NjW?&&}%wNc5g?mf7=e$fdM_HKMSn29q zuR>I*Rp_o^>iX0@S>yiz9)I&*hoX86yNq8*$E1}81#lr|d(LDtLd+dn=wPJRhnf8vA@tE6|1}*;p&e&$y^MJg;N} z#0?-miT?nOaXkM3<@Cq?eEKub{_D_~l`GJX{y=USTn#er2Z+a6F_vL?+%tW^(Fg;W zRc(Y(cM2g@8?23;Nrz{(%Z42s%ZFCZkO!xv!ImwaQOfT(#QOp~3cVsxZr&oQc7|TL zLLBwXXn#Zkfy~O+irfiDw6(BH)8K~?g`WK-C?X^A#9?EeAtWs|8*OEl<_*-xE5p1^ zowGI!^?`x6t1Qi?Dn5hIoOXho`jKPla^o-8#ASLF7){S_{27Vr$J4K*5Tv}!!G^Kb z5EjN%4EC3@Q#W4Hj*l}lOf*uRWo=>K({z@vhE=dp#4uh7nN5quN;>DPtabR*vqOVk zC2P%nppN3s?k4&|w~0##-Q|aNI!%l5xz*nkOP1$Euy(o#t z7D|%mfjci~3M;nR`r``IrH8lIwEEv-FLgmlyv4 z7vKK?kq`VI_2As#>MLHJXl%&}#>M*PXzS$k25d|?@_ z`P5i+_nk|-GZSgSq1?1tH!4o9dPf@%Zek0(@rYHz`%5UczDamKunEE(b&GFTT^*)D zH^ilPg*bpKqnHDCQC*0^8^;POtOB;Pl(ZQfrft#ph~37BBAO5&yJn>`f@h;#m}VXnLpg@(n@5ak^qbf#y#sJNGJer8)Y%yYk|Pe91Q zm(*{p;vyHI#9@DZ;Y!weZ7s9k@Auj*lN~s9mretb-W0UL)pdqX0KM312vMJDMulC$ zu8^6za?M%3M@f`)#hgSSuyZqOJ>unfn#Z~+#(xu2SI=m-xYfWW10v#Lo7G)31927NRu&4hGZG3`wuy}&h!;+~#WU<-xUpFl=SHU{;*pF6LDXX^ z$*6!(=S}!wsoZ3*lDL)ukiu{MnE|I+(l~;-0kXR3?J}y*Fwu~2q7FxfJI5Se2mzNt=3w%<@q0dD&Z!Z*;r`|g^)&pK-26uxs*aLB>@9;z*dO`!l?B24&Ahl= zbRz7RW-4j;an>-Rv};u=Y-yC?a`i!Lu6jbn&;e~c#>ilWC&-+6hGv%-aZOBze1tHQ zIIPb!=d5_$yGH?6EGlh}Gayp7-#JSh;X<7y_O8R)ajsngIf3O4f-AE$&)}jiDsXcH z7uk!+=!2)GJ!&u>ij^43o>Jin?HIKb;txYE0|+Axw1<3&)~3lh%9uU7%5*+wB0Ej= zdq;BiV2dllq6IIRSu*g=FO~wMS4a!mjTM$uA2upS%mD0X0uufEIed|xheg@nugp-{ zt)>q=Ji)h^Lp{t3);^>cxJN0kGRo{i!)7>@HuNmfd5K8Ff)W_&JdJ^JgXC39kdY4xji_OaH*xy-Gh??$8 z5p=7r35;E!e0;#ITfA~Ax}r5@N1>ULT;Ro$_m@)U@QaLF=$@78m;V5lth~$h9Hsqy zVvl(ABBhW6(k9p;oaz_`-XnOl6f^dLo6n?nFL;5I<`<6vnU@0VEomN!aGpytOY^1O zN8y8-)83$VgUG()E|xyF4|tc?)e(eGMuz?(R(b%T^A#G>kG!lvK&Z)0T}g$9OXvNF z-~!)ctg>xecM{B%eYFKvR&xxBv8E`L91g6wFt&farb6|9@&}V+(Ot!ut*vVq`%Nf# zS=wY?w;Jw0hz8WQP&}|Y;!-Ip$9`|=2xCeM(q0)-b9}==PI;NRuG))h3U5Iu2HXTP zY|rK{R*I*z(XSYlb!Z0tpbiS9lGG~jadk8W5c?huW5$mVYqMTz35Wv(_9He>faKAQ 
zs}n3zE@J-x#mkSS`cyqCb1%}v>jiOT#OS5v4UIDb-{K>e+BHek+nWXNEl?}0xpsi^ zd(5pXMZnCjnnNItr9i!bjLSpN0_ss468KcGcrHjUaOrxUA%Nv_nRr=k_XlwKlyE(x zaoRZ;Bp7h)sU5VY56P#UHkSZ0Lp!u_TKV?K_iXs@yXqR)gC%iLUMvGJ@y zcQ>YNsd|p(lPsdP^8nP0P2<_$)(Yq-!oJAjgc7Tm`T?I_2a8%mfTi)H+>{SgyyU99s^Q^ntRV6 zjmvVuPP)U}@@6b0;`eS~#qpOm8X< z8-MI;x|gf|M74&&*Q|E%_^coYtD{o4saJiY0n-Doh}uE+Z_LQTv$sfri&p$|1(MZk zFEL?cYFQDq({{WJ{+#t0nVI7P`V|Nfd{KI@imW?#@u8Dq_hNa75`Y7dsVd)aCUNz`N7^#A6p@ogXZlJiO z^dN4^sb`3Lz?)xs)J|UMeykp94~BIh$zu3i%<_*L5i{iIm8zZT`b0B}g~U##jTcNU zD)+t^w@@w#bjSJ$NlhLLzSy4!ax>h-r-a!oL6!rniBJ-7jsF0Ng981OSRmr2zsP8V zC^h}l?+hKKJ81huC!Bj1QJjh$(JdD=$C*`PFzV_+8F^8A6A_doC9>#O_n4BViARf= zYoWV1ANdf?sqE=A24Q$(Il$ft#+eJ=Iqj%XJMePc-kt)Qf}nZq!h^4 zF$d%O1{K8g5nh&hE?E}g3_T^I5?UBuVZ_6VxEE{_fp&zmP!_M^ zP-(96#fp5*&3QKkfvyz6D%b3jdV;uX-W6IMNCQ=#OqlxkTXSQ=sQo5h zno(a#k|JQ^?{hP^A6A=+PPQZPViauQ^fd=M8<*xZ--i-iCm~J5V;euaz@>v)r1RPh zZIm9dT3-XBZ``G3IA?jnD2Zc3^&dtrlAUY$g9%ZeDw(({bsb=`vXHwvedb)Drzan? z>nf*Z2UjQ7F=|xm{{WehPERF;idJJ7A{!1r&|Meo0ZG!eV~N8+FV0X=C(X_W-V>VW%CST=`{V%Ct~sx zML25|3V0ILQF)KZ_Ljmhww8me!i}5p=*%TR&!Igt>4^GMEcAA;OFa`{IA+N^5TlAB zUF9tZ2H2F!k9kwXH#7?_za)8T?9Mp%btx>D2Ii&Il$FE=sHtlBtHMeZM+LsoY20N? zW&zslKi<&sX~D7cGJG(|UT))A4Y!tHbU! 
z^&nebU~O;f_l#{|GUf}iyRgx&uznPH=k+gZ6@gdC#3D*E^;*Qtb~%hFg{PQ`olkc# zE2ArSa>J@1V5n|zC^&10h}>$-6IzB;mMdKbQcNCQJVp;F0s<77mdhp88rjb$8^h3B(#USHFMuc22mWV5*|+GQ*|T;rycw2d!o_ol4nGuQA2S zg}?C_uTein+f;R63`N5I1qMf@2d|40sHkPqYCl879A<60z%o`Ysz{8`{4yz#}ug8$g{k z1Ymj&7{6ZKJ%a#7-H)&yOPeZsE#>}NlXJ;;ZJ|LJ`rOke<~A^xbuJ2u{Dusm0wC#& zGy&VpzJPDEs#k@q4LzOWOh-0G{T1n6l_~_Nk>~2~>24+_{jd zo38X`9I-l5@c#g_51ir9q&Mz~^eTx$9+nGYR>T=hT%;N=Z!+YXCE{p(9}DX-45PuW z@N=YgsLM%pZY|H4D|NVzB@NWC&_@Loqt`G;Us;-f@~vbKz#Tp#?d$@k9YenaNyK=l z!J}S|UiVlp;v}O^3Jh8=mD?9s9PtT*lM!GS0qRtL@pEa2_muQ7W=MFBPeqSQqGjNW zh4xCV#HFpBh?W#WZ%fVw3>&c+te(=mfONUZ(^%#V)vUce<_NXOO10mh$9l_ekd^^H zKe0A4!P%g(oFfviM36|GeQC)qDIH?`$+S1QN@7!WrD_W=nWj+$NfB_-N#$^O9R z7X{~+sBPGvP>jfT*7{0Ik19+%L+Z{D;4rMC$MY?%`DTV$C+A%1dX`U8{*q;fPe~E_#ilCKxHRwi!8WBzj_(j1D)wT7( z@g3YQDhkD+fHPhCPP&~xGh&z93@v-t7lEwxnX@%^MNh0MUELe-m@5}jzDMR+tT;A( z;j-GQjoeN%C(YECO}>fu2|&g%oBqrO_~QPiA$ASNGNYnmW>wxU z8=U|=v9sh~@0m_$cXa+?C3fKP-Um|Mqh@LT)W#LnhHtGvIdBC7-x1!^o7Du-VM^hS z^8^8;d`OZp0X9}YyLX7_5mwgQqB)0d%$-l9SPBXiFsn}ECaNs1vL*;%`3ZL%E?ur{ zgCAbkRFc(fw+w!hF&9H^eX{_2gLd7y!823qYj=hBlnzx?<#o7jJ7{(xZs&7+uu19z z0vAE85x#;`5!}mi(AW#o>Qx8gw={DLjzF8jW@!hwM(3m6Q;RvoSYSA1k6(%6IE^6j z7om)+mD(hzM}%zdQJy}WCF4BHa=52Zz-1Pq_?FmIv|F!fU>B&%Qze`?f>`KqT(-Ti zVArxYwIBf(&9^zxjkS*}3ye`R!Q`}M0)n!EucWb-n@N2TGVvQ_9LrydpCmlJhG)dC z1=PZM{mg8cL)CbRr0_qv^tf?L!TE77I4>DszcZhv(3JE5BEKX)1H+evIMv zh7f%nsJ;+7T-nNaki+W4jiv=f_{&-@lZCGk2sh{^MsP7otrr*;a%(eCaWpO&F5H@! 
zz&tf8RecVKKMI_`8icf$j6Y*3gHAT7N(f&dpC7201sd4jvQ3KQln30F4wc(+@@6W@ zwq$=W^5jgS@?s@hCm_)%f{Rvh<~=a-+bRGIlB!_^Mc-MK-IHS;{U9vTv8Fsk+*tAN_tY9}YB`J4NDk#;^ubf8=zUlSML1+%0(&HoD7wa@@ zFQ^|f&DHXKZVBa0nt6)zvDW>r3f_%dE#cWb0erV;MIJ^(A$JZP%306pmL0Aw&v@Sy z7gvY*jsR+MPiamD4i`y^>TqWt5{?7UVHcE)*!|`d=n`G`i;_7rRpJ{`WwSnK5-?z6 zt!Y&$^WHUi6U3z(yIYN*mX&yvZBkjDYG6aAHT=v}7>}M`h(HQb@X)CIsTtyf`-G;A z+VtxDN><^8BD^{6c=HVhwnb?R2tZZU%YjV9dlp}0^4j94jkaC8!ALV_sSMqn3=9?N zsAi0|I9F-3qF~tk%&FsrkGvl=U1=R6eIi|dej$2S#7@t=?HGBBT>4*JzxXIJ;e6_0 zh>BgzE+dsdYRsJA2Czkei-ihL0vOUOhrTFE#}xsKcuI>}9(b5(bt*4YIkzxf2PDz` z;%ElVZd_sZ1!L?p1L}eZRefML++g2VVEd?hdf(6@&sJ`y-eCfE0>gk8rO%uY1|IFg zw+BsK`GS;2V#+1!A<_J|4p8MrPrS&oj!g_rss-Z?evtv8&KYlAh)~D6hu@^6 zIKa{>F66kblC8W{wq?Ne4$r&<?Y;Ql5MkcEoM;QzTdUah?J4XfQd1Ed808=*ZH8H>2(hTGnC9-$FlJfwI9NSdsXb0!48{& z&NKU-)dfN{jUHgwxHv@ML5@}*GeY?U4|ssSg*UEci{33W2qU2FiSRu+mk8Y5Owzzr z7(J&XUu}C#<_l1aM^HQ`^-6L|!)!N0VP0Zjo245eI~xjawSeK{`?Mrv%=NCK77?#5 zWf|BSnN?@+ENgb(eatUpT%kud)x!(o7cOH_k>LQ2_!AeWW5-C>34+ZYVoe6zAM9}f znD?*9%Oj0*aKRVhC*C!ETrkU5x02(l+fzt$6Hm!Fi1M?3tPL}4GOL8dQ zbW61Fdgw7cFBg;X06yft(2|GG__5WOT$sN&f>Va>qExHcH|7Q1BeQk<{19Rx!>6&! 
z?GSWzx%?klMddcypVZYGp${kZ8ZWrFgf^}8=d5$W+qGU`5DwL3G>-%p3?@(8Dx~aP zFlQjj9bU5K>&Fqe2_DV$HfM2ZlxPh2Q)TtrciME*&xwq6`wdOF@u(_v@5k{uWNUZU zF$HiA9e&_QyTi32@k-h6ugvI`E6K$1yiOP;Ji~CEp8b2ktBA;G$qD>@RM7oEijjzy=kZ8Ua$38p!g{iYKhsfoZ3d4})iB371uMY>kFgaT?SnS@K3 zOVY$L94&q(8`#ks#0;xXyhOPBH0}P&;(o|V`8rW@#+b$*$L3}Dh|kC6D2CH|UrUx} z4%eSJnO3r^55_<*7Y*0H@8Tv$BCq9nn9{9t!BBWh0dU%T`mWI}%V8Y(g%+s42%_mj z5Jsp7W~Ih%9j0R#n4^Zad_hyKBb!OQF@dsnKzElsM-(~MT+x>qJt8DxLa-}KYh6Cj z*ReyFIb!C){vSz^2zLhm0J}pfQM+I`f(T9H!iUVI0}n@{AT4ZpK!7sdknz7%+%WPO zX^OhW;Vr>j!38^uc{|1lf!hzHgU0Tk8^u86uIa_SCc9+@RMIBu;%UdP~- z%8RBU&ObCRUm7_pmXUo`!wt*nE77S*e5X+w=@4v{WKzjf2>FhPORzyxKpP59D)btC;{O09E!u|K*R;|`GAd7K)HdUSb}y6b1OP79hx-*5CadeXFqszYLd!ayhaafW$ZSU$ zeV|s(?U*=b9%*pp{{UqVg@qAF)y}p(rB~U7eMp38SdQI%GS$Gj2f54zcn3d>w$kI7 z?4quXaQj2nhe3(BL3kYfAyy&ask=mP6ih~x>2)Y;bizgGsSUP7TkN%7Rfq#8WmmmSgZVV|$zb0E95t zSDqq=Md0imyUZNBm#K=p0hi1(p40MfB{i?6yab|ozuyeV*67jpn7#_+`=;Rh&!m`B z<53G=-0EQ5bH8av65%L$iDIty#p0%l$&Ui}7aiY8RVeL?jxrwS%xrGkP<^H=c$_L* z2yi0CL2lOOZ$Fy$n~hYbm=Gn3Z)Zw_ zn}lN6skWHU+Gb&Q_9ck(0S#8xS%$snA$L5=$9Sr)g8YJzSsS|sUf3Q;LV)5Oz+P6E ztcZ%te8!uzw8vyl(IxSXOFYIWej}KL+%^rKggtXE6Q_tRlVdT*?WRz>I!Rn2)x)y>OOGi{3T&?>P|a zQ;C$Lme<-OW9&m_5Pq?CKM-!zC~CgB}nz@A>S9VgF zSQpqOG-j=EA;>5>KS%BgmH_Az>-dR2Z0h}qXaIDTe(-*Pxrk98%P?aD#MpMPR26jf zY=aliFT|=k^^eZwJWA~Rnolb>{gF&9r*QOOK{uERFM%W414= z>xfDYWGa-J{fMPeXE~@wl%+UlnvYro21M!L3(qT!tF6LmGq_p2?fQp_(v-TDc>P9< zsQtpGbA}v3;X@V1YwHUHZpGQCfK(W$nl<^(Wjh^8r49)769cG+?73i7J&)rQC{?h` zs}j=|XN&e6eOS)SpaponO9esd#6${!IB(me?yod3Q(@Wsr3O&sm5SrfqZcrYVfqdYWBL`N$bbwNvFh}BNHqI74 zih%GzO+X(PnQdD5{m5_?d|NXBe+!J0wa9ymbCukBOnDoPJ@p zzW)HZ3H5_S>#DwBWf;3rnZr8aag8u+*BKbEV>d@cuV{6LMOp9GVVjOszr<`+ia&%n zrg&I+`^F`-$@;ouXLtDd z?meYg8RsKho^gUcIUK|;uu@F>&Z=4oikx8fC5 z<|(OIrkh+oVf>^>?PD0nj0n|#*u**~yhED=TzU*~#B>97na&#>h|(S^=`X9i=HF?Jk&A5a zQANdCY=2Udd{``~R6fO(KxG%K3nF;JgE-6sD;&%e*1Ub?L9?aln4%pFnB;6PIhGy7 zdNl>DV%469(hj&~;fzDd9l1>4pw=+^MzHe5gOqzXjlYx@7I(N^3$%gUMy+6!C%P6_ zz#w$5fsWCBft0yNQL2VjcWR$VZH0BGF$cKB+g09S)7m43JWY{v=xftRGNz(muJpub 
zl4GP!@e>e3L=iw8*efdzusGDrJiB{>|Y6rgGsrQXM68=^AmOJFL z<@yqkQ;Lu0wE3IE>o||xE;6-!>*~~@i!c*ke&9gdIx&;x1^ktkQo2gf{IZKBg;`VB ze$b*J{AMdM`s4Tsn-{_Mlr^)~5%ip)_%n1otiEPu0>wA+a_2i|iQ-_Mvf&nyx>9>e zFM(@WuM&-wo302iRjlp}*L ziiZ}Lfmh;Q8|D5+gWm%(M>sa@|GXGOU!qVK&~SgFVmyf5~B~TIVE&`CJh9l=PWWuW6TaT;wxg>p5`A^EyFJ> z#36+J_&3{jf zG^y&h?J(PR?EYD_K2OU7jYC!#4~b-=ohf1Vv<1VupNU-z+~b%dyf48L>flHo!Z2T9 z2aa%=1l?M@`XxxaaX_jgjT-Vp;;yUt-cULQa~YOda;7~8C~prS>^r6qbnSp%6T0*NFfKgYJ0~}9JBO(A%~rQQzzCG9B-DI z6S;dvURmP&dmMbjt4~g_r}R$!ojn9ZvRr|lF_5_CRYoc^9+fIFkBH}i;vH+=W!!he z1q@p1a3=xo^p6%A%hD&H`6a8hVH+T;$m{nlbZNsh{7R6Nt*h6}#bNgT=0Rn-Hhjw~ zXj6Lxr0dWh_F%ugH1f_{q+f}P0lUPN*$Qtnd8=5iV6`2SVLn>^jn}{J^kVur8&(2<2%o$~s@XZhn|w)U!L+ zN|ml{o+Gfgij+RQ8zSnJsItH_{{T@33ahjbX};V~tFI8OQ@q5)Y~VS}V7p(?U}Qtc zJ*IN>GJcs?W9Zq4U9?BrBzY6DzgaR*a?jv(;&9csYT);6u{RsR5Aywuxr*Kgdh)T-(+ zS$ixy&$lf0bbzW)15b~|OS|4T?EJ8T&bG;gPq>0NfS$`~lmPw&ZY5!#?e5}?s9l*t zf}JZt5VQv@f4J)l$Sr+@;-?wkHx{!mW!6Ahd{F*aL|S0>4xd<5;PuS*H7Nn7aCM2s ze<}*DO=iq;16wMlOq@9%F+?x2jB0znlJiy$Z$ts#S8cikD=NuTf<$02PH~}uL}6yk zb&G8$DXEKh&~Nn>3TK7+m)@`G!QAvMl6`00=953h}i_#EHbbYtIhn3D|5 ztfsvTQsZCvW)WIFFc2LEWr3{4Om1Dfx3n>oY`;{f!y5;Qn{@ak%fpF%U|X>Ab%OV$ ziIhIu`_3xduC#WQE(pmk{6)0ya`%`Rdv6V1V52%QFn3B#8^>Rd+`wq&fzEtPH-$T5 z`%8q=AH_kP26P`YGPqHMK>0>Hd-3;**A#7@S@J=7IpnR$3lU;v}8iBj`ruKN2%{6#H_sMK9x8~Vk@fP2b;Ra>>L zqF4(rF{PSjP4JZp9PE{)kY5LxaZ%2+C3+>Pxq%8*q2TnC4(Y#xFc{e6{2uV$>~H<{ zhgS7vf=a018iE54OV_y+(tMc1?u#^bULI)5i!L6JbQ6r><-`>`&wo-EE$qrWO3Zsp zy1e=l;GB9NK8sAxpH?TPB1V~foa+*dqpCn|RG>mlV2IV=Ory!uh;g-=vS%%a2KggFg-T|{* z$37#IzpX(l)%0iI(Q}o9Hxttls1?$6*f7hPk5w+!C2jYXElT9Io1v+DE_r8CZ}35& zdn{8#bQDQV)S=HVk>(?Buyh<(d7X7@3$K1yOpzHk8r;Pyq0NO7SMO^@he@1w){)$ES9I;SzgW7~m;o6ptnIctZ|JU~~5eeC!@a_t)&ofTt5=`lV7 z%*MAtdJyX2s_*>5&RY(#kdr0MDykg$fB{uFC$M+gRF^*xhCF`aRXWb+Q^V>*6FsHp z_c_?1$xBAg{n6HplQHMW{b1k9FE1T?oynOZtu#AzCAx_2FnA418JDJ?ybc#la{WO- z2du}SY59~!52!!zDD>?rCJb22%a(%NBo#3HMj3ldlQv@5)KfZQ8q&T9fn{~G1Eg~e zP1jk)=VMC>B+lvl+Y>NIL%qLosbc88PX7Sxy9@5N=;w)Ol}#$w*K9afYS*EeXu1$@ 
zn66!!t(5BTm!ZvlW@_*Yx1-i(8+<>mW2^pn?trhe{{RE0ct9YsV2?(Z=|XwKJSan9 zX|?`lH(xVIn=_Cycl|-GQ(pf7L%ebOS}F%aw5Mi*v-#!}g&R3cvv40O@F8A^IR5~O zAg;Vmu>J(t>c%?%03<5PeU3Bqub8#WU!nnK=g}UWaSdP&9n2gq$G_GG4K#t+B_fQj zW-xYYZ$9Isrg-Z&;e$?e;g@J~SId}<{TzZ>L5%+Z3qYJq1IgF>h}PRoKZw!oIEpcxF75lhT6Zkr?jb1glU!N zZaGSc>2l)2rCF$>by1D4v^oJSQ6jb~r75-PGh*&pYKN}Qt|bU6^k4N0@`{A2yG6&m z05p!Tb1iFZ0eTm2;vv7G`+X;9e+^^p8_x`0m2uV(+inl?XgXa8A6<`g`w#){(tGRs zwCRxPSAXorb-CF6C*oH6DjgF54M$h?4_2To;Y6*YpWIOh zzM`=AW@I`BD;Akv(X5s<7o+$6!4X^Y95J9Qv%e$lq6Tb;Cx&)8gF{6w1AvMdojK^c zumOJp`THH=V+_Nu@9PDiw(FPGh8@yXU{$vcSB%4{r(=PlP@Est{ASI%&gs8s=VvJK zm~8D^^5$m;T>EZZ#={QqNZvV{J+5pUSruAy$B41)eIpbu7M|1HG2w`}lOAOucv$TU zhQk2lcRQctrDn|?d%{&__jmI$rz2O=8=&!LPl1S0yH$NUK9%j>A~cyhKvn%mZAD!; z99j`A2b$7KK|bU4I)?UxJ;&+^sT}DZu{uFa`U>lh-OM)$Gje<*NGF`yQ7w*7p)M?? z`jDLZ6>jtRWey931zpv)3uL&KI32D#D+2HkH*Q!+MO5_N5PQH?#kDK*kM7XywXlum zrfGtL0gi?y!b>?Q*Y|m)0>RTk_m|}k4*vkL7&|R`)&9V$?W7MBYRnF-3RpwVLFSaW z?HrPKSgy%!Jr_KU#X?Jj+c!-8l_XGi6W#Q-g1bv$*J`}*zTcj$7nFC{zWKi)5L3LbA#&`OLs$` z@65t`eB}TXFdm=G)2k&q%(x;LIkCemjUEA~&0guGr|F+IdqiO5S7=#XtKtG%23=3) zCuRQt;J{SR)<@x^9aybPOHG+w7*unqs5ZUFm{aOT;FM$07-o@K=clE5=@xEKF&0R{ z#OodiXl?}C%Zr)U#^M3ppod^_#G(mBOk?pqor19yBV@d}IEuhaNSMVU2vx|}#G9+HMF+k-9WfO2u%zf#hT0=vVr)8HDETbeG_FT!l^di=(! 
z#h7(tXo`%o*?(sNDg#5XIuDZhBO5Yo_&oT_1hqmL7d_5*FkslY0ju)FEUp}$?{+n5 zQoJekmYk>a=H|hX=pDZVO%7|^QSeJ=CE`W;kBE$7U)*W)Qh?tDzuxmAIHlo={19Fa zPSIr6oH|tHj-BPc(Vd+!Ho7`C%@}6F0jzTIE2ov(fhsS&o49HH9lIO z3*sbjJLvf948(|yn>8|7r#>Y|(9#rTvm?`$;ws#}ml(O9Lqecb9}%52fR4wnNkVM- z=>r=fHflTIts!^~_e%91dI*lA7Vd~>X4Ge#{a~pan&|TNm?$dOn$Zc=zzd&0-QnRm zy@L1mV9hA)`uLsbHtWyk6`Z08(opC(Lw$%ifobhh+qOaC9olK}n5l+^_WP5Q!`SiN zlm!p5v+l9nv9oIZZ!Zz10qH*vSD(H-2jtuVg0OXW6#oF4<^AGb?8!0I%KcfHyTy9P zHV@ht;0QT8p8(ZBD!+((H)nQVf20pI{{SDt8e4Vwmt853Y)YwEg3accj%n*O7?&%X zeG?wv)r{>VekH+Q@&5pRlDi5MS0K#8)I8Q%J)zZ#UzGTd-!6O1#!Ld^!xpGd= z;XY#x3SeA#dytL6O?32a3QMm=F1hp;Th;_U3aB$F(j%m7n2Ujeb@47%Z`w7Rwr0mY z_?QyJjMev7Y!p5o(W2mi(sBz}t#%L;t!i7c$(;ho7&+ z1=9ENYu)GLtQbymhkPf^qxUVpNFET3525m7_bGm#M1K(2a1fStOVJ;YBE8*vlxN~S zwtdCKsfQ}}Tlk-pc`{5#zGdRl_bqp_^nis7DdHZ*CFC(~ddhJR&K+E~=i*z{mGHa# z+(72L+7CaEFg%?5L^gqB?g59UBB5@d$KnhFgg#|n&)^Aa#`=6kLCJ#$44*5B)Jk6Y zeNW6!275*>*h^!aN>L4ubKCr=I*rCpxAuOe+@(=}WKQ4Z>E>*Gy)_6}x%5dw`%mV; zhpm2(61@X~De5rx9;~#wPM0up&k-mt)@wKoz+cem5IwHMAn+@87cuX%J=bKn0_^nY ze|3#jHn$MKQttWr?-FaPb^c*-7Lj&;u$hr)W!zGnCSUh&nRDLvpYF3zO%;9h8ZmES zrMdlpe(X^^9fPC%L_-#O_qE~yBXnNIpLlXDza*u@A3~wg$7zvwZ?+}~@LzJy=(D0lAKPspQNq5 zf5+Aza34S4F^0Fm@;}+o*lj(B_ZG%(zU1Dn9?e7tK5u*bOPRl$LuS3z7m@{wt9pCH z4PB^B6s@<=Kfa{DbW5X0_9CVEL-d%+@>f>#+6+6x_kEyR3_{hNWxR($xQ-k{t(Y;H zktUlOqXc)68HZAwS5YEz{ugTed(TVr`@OCFIDizN|eFT^F7E ze^ku&&%D?ULD#3+qb#@*r_m0T{*&Be>c&f!s4@90lp)JLU@OiebUI}CjbNrOHo-lg+;_VR7>r{Y zX-sP>fR#uWQP6Ao_LSl_!jh%bCAXtTF$%DaFT{E!uR(s7UaNg5@#w1aFh#vu1ga>9A%<~x0l?!m-Z?&TDY!F%DUT={$dgiZ=B5s2R_lW$a{g3*a z%5rDKw?c;;nPzw^3-DBG7Hsn30nAk2si$?uS0bhCoXh2vW?_9N)d@$3^DmTKiFL`6 zDkyc9t<1ot;u^LZwx$Vy-_z*|acSlUH-_zZ;P3uM z1A}**?TUa>gz?M(HXjNgmW9*y`@@!K4*jJL)Ve!#`~6_n@~-FLMLE%ZQLJeAL{Qz6 zCm`&K31-|=UH#?(ztIp|l)czhTuYWOI6o5<`o+WW#O#4SpSTt;L+gYsZvtHJ1G?`5 zpLG6VsNtaRa5OR(-3GP{K6M)5r-PQy_6z$bW|qFOt~R27d6(H)SLnG_i`eKj57WMX z& z;=G|Y-mgr5l9G!_$8LVIjXwVXBR4h0Us!1t0zTh}2Wpj9R15*T@A`uRaqeABz80hA zrlEv$%b(o5t&3;h2ksX;>m{JQOYJobSoj>pH{<;6hNpi_&*b==i7xj!M#t}tubXqs 
z;c`{@#Co>Wl`M8x&N=Jv1*lgur1;t=HVBL|UHPeDH>vF^b~hapFzu(fClP`6e+b!J zx9I}4Yd(r#Bc{Lai9ki+f8O%@O2qQ-UkqulXKd^=tD6 z97n%dlAV^PBI43Dj#e=zi2T2 z#Nb$TJzBY@Hm)!Z5gDkjIsFB3%mUhL_=@-OJQ{E1d*46lw*k@wa zW*9nZ;nf+z&~x}oEfgY+aWt+&x*xSUODe36lGo znL6(aiJ!Hw+rLDouMJCwYr~yIidR_gAD2E;n8<~f8FBHIsP7XozPp4^K@$8CF>!g^ zUoPUsF?u;+3WUb*ZJo>!$K1fTEP4L`C1%&azQm^8I85{ygW6N8VXX)ORR9Njue1`l zyI)W1fz1`{%j&}D<%7j%<_leDJ7xS5;j&Zm#wPy&yJCt?f$%@b5L{dGT_10Wa=@9e z=w`mtCKdUA$ZWUI&-(^_ws*(rhl3;VF%c|U@42X`gCX-~Kf2E__cbh~XH!jP8iRp@jKP+-_!E#1kAL zLJM;G&2cU36>)N&J|){96av65*GaD^VD1D8 z=;8X4H_+^0Ws$|d!eH)3z8}<5<>0>yo$`K0{{RQfp^Boa_M%g0Sa-AjLEfxg2OnZ| z5cVd;QtQMk&r3dHx(Fclw*A;u?Lz#oKTou!ryz;(Eq{CckZS6gm4<~RwZ3w0yw0UG zp~(KGLYP^w;aD4aw!d;OQZK?+y;<5~IWp|=24uB%Oox$YymEF05V>dHiCrDius$Ga zS9LF)nN=D+qs(}ZCN21e;HpJ=#m~KcW3c%q0(P9lT+f0~n{-F;b~XOYjNOn z1p;8Nbhx!TT+X{rgsn_oij+Z+yXz6NdlGKf+{~_0!*t#M0FgTEJRqX9I_*R4E%5Zf ziM`cQbhH#MD94hTBQ|imUgfJs&3j5eX>MLHUGjIoVieMX_`kiP9Ywo=SN+;hipCx@ zQjwugmnYg2(<9eI^&{lyA9+qwk_b_f-ysv`r)Z_lv)r*>O-qUzuXrHNG?Y_UvS6w` z!Nc^J{u%c`q1NcGCP959=`ae`ADNKGF!V@pW0Bec%RkIOuXWFsYkgPdpc`8A-~OiM za6^wm!DOXO2h)WUX@&YtR2-+bp>1&c7`yrUB0ZOZq&h&d0k`0b>N6#@LC(~t(KY2{ zQa=v~8bz$uIsq()UgA<4T3RtG2+oHW8UohJ?kC;30#~%D8Ck5JPztia`X&1p?Ha@B ziN`jZO0PpO60;nT2jQGT*hRof%*E#AdrMZPb!oRH)T?W&TCuljFQjU+9y0e}03$hj zOg(bXFIF}?HoQTtyWOpRWfc|NBbWQEO5d`%-}we!%{oF?+F6F;&dAeyK66KR~~!wHRI@#3YUz#M19 zT3XrzaaeFt1L{P2(5+%Edf68Yu7)w{7&C!d^awNQBZwSM!`+Q&c}2Ucy+QaDscLau z^_d!2a_Zp72Lr6;RQ~{k%soEm?hJ}P{X0rD-Rmys;6w`M)OO{RjnFnaM(Bk+p;eiw zdEtK3U@sPVDf5|TsM7kcEXL8}a+oi8h^}%!49y(m`&aHav2V&3_PI*r*ZAlYP<&t) ze^}3u{Qm${#2OR(;tFNM^ey{+rL#2f^F!#0hhC5vnm-wr%)#6Fr&p z7uShzIUv$>h3a*<7VVELp}63i*ybVI4ZN_?IH5tgwakagYuK9CK|d0He0-f_%J-c! 
zBj5EL1|^nuDrUZ#%x!a;F%md05dIlaqtPilKG+ z9v#6=RaK+2%*;hoW+@G0f*wI=HG+*KI6}h;_Rim<{FZYqk5sJjolu3Da*zwN#q(@jWy=2mNIr^as<~NguqJ~j!4;>9@f=MWoDSc(?M3srUoKWi2h_k57<>SfGFhc^8L)Qeqy0~4Pd%~TZhCCtXAy*MWakL-?X+FrFjPIq)R%V}eUxl~a zL=`+6?Ee6>H8f@0_?DD1)OlaBS@DlA?}>u8SafKMZxol4k`v9J2UubXOc(Vf!^2?h z2sDd)>S(?Pcl?7h!tVv&VItWsyF`OUArlwLCMo$D978U_+~+)k>WA$%;obCJHsir#5g0u@fQ^e z?$?5O;6tN`vHX=j1O~ULca=q~k8ppd%+DFiE{)VK34tE#`j|khuIPV}HS{X_{{R>T z!h!8qQoZexJ%{&+W6j|1*#7GgxRN?DiWIQ*b%}JT?tAp}Dq*9OU-c+eJrA{t-dL-b z++royM$NO{JF{1UVPAw7k|uJ~I_7bObb!xyX-dxOOrh&$=uFoIO`!h(z%Ivlz1i^z zoGxy8b1>4((_REc9jc*HgQ{pm>N9W(9*c7+MXo|ngpQa*g%Rw1X4Da}pBEnorAPga zz&e!=8=5e6X*5j+--!4WzFyM6dq((TAXG06voIJ7&R~q(7nv7w`Cg9XwbZ_1!HiRB zc6A*;I6YT7sP~Cnw+tzRmIoZvqK}|}x%J52Bac7QA-loJD_5bHyj=ElGP*^e`wyhj z=3&Fp?K&s=I?Ivr$fyB=hYp=ySku?D75#B46<5x!eyEj}vi*rcX

SqOE^UKP(12 zXJP%+-n@N>_=C-8uRHoiI+s$BO7oYW;#ehu{{VC8fx-KhQ|1_w=q~9hm_#Q0hG*LJ z>LFs+X#W5*yQCtJ+N{wi@-P9~c4Cx83K?Yr9$;(I-V7$C@tnL%z31TieP-B>RIXTE z=K{L9bVNkPncQQMFw}CeRHv9=Y`Iz8R;tJM(l=nzol9X+x0lS^Lg%LqKzCu6)lP^Fj0PuyLU zb?eofLeO*5!dG*<)zFqe5tnnK1>0IW41Hy{x&U37ejuWB7Wi-N8BNC>_WuAnM((Ju zXxQ*SC4F-S3N+B`q(yV!Iqv`mHa=JZ(%^bru|>(b)KZL-9llEDSUXo5WtF8LqZc&t znI+6uN14L4H-7MvDigCbE&wQp<&;3o6$p?$7l`U{m{l@GQJpBt^K9ZHsD=X>%*9S2 ziA8-`T>k*Fsn#NxF$0VAT&*^4msoX|Q{1)7GF7K%UE&Kl13s#u#)4c|gxvk7A`@Ud zOdxX8U$~e6!TONZ1i1tBB)GsfP`8|Zy<6b5|UZL_G7y@;$o)Zs+|)YJf$C$(G=%U zc$IUqaPtDDCCej(sLqJ^B~sYMT_*L+i({f)kRZv^gxe5Q@?GKz><0H30#HoT`O+yu z-f%oXEDCI|hu%4InWizz4JFD31StEbG4~o`jw7p&d6C~Okrf6;;kmUvKlL)58k&Yn zJq{uDRNZulv#H{x*}sz#2u4MwXtYey&BOF5QQ0zGJ%8TV+7 zj&3jFAy*B4@l$8gZXF$8XECu(G}lM{AaRT>{pKmGD)p#2Fjy_91PHu5q7`O$_a@M# z;pBjuBguvW=?Am%Jcy0vSE~S-J*EN2h;ZLLO!CFBI3R`mMxJGLmv5o;VC^@QCV7NK zc9gsXDadnddy{?UBo3;z1+ZYYE>n--GXgj>t01jS%y)O0egfP--C#;$i@ENDV=q?mYW(7=`(Gge`#At)!E#<#?vhPdKgif%zN;QwRq4D09WBq~K(1Q4< zCEVt#IjjU~Y`8J)hR~qp%|e8{Fq;L@&Ly%6D!Rt=h%hyPKXx-(#gEKR0eq0&%fLSW z0PHKnyX0{fV(H1r{6u9;QPU_LnmioE2Nr9?IFW*JABYoXAkj-$AQROVuD~WXu^Ry9SI=@PU+qUZxiw75rg4PAa@JrrZo%>4pF<|BujdbTz zYDX z^287u8)@$q3h&U8hdFWDEuX8nXqckPkm~g*QDxLy#Oi`mA86bfn3fwB;Fh@o)?r)i z0BXV0nS4sVV6&eK&w0ii9Ahy1aC^$ehjCcsDbeOQ`DSwu*(1#-gGae) zrliwrlu*I@AIw-zuiN!H?PB=5uha{-2eS-3^tZO4!fYr#X@FV|r$!S;sPD+ZO##Qu z6q*c!ytAsKcMlQzM+_qz)WjK>aTv+*EQ5$aIEY+(v&vr4EuMrZdXD~uGVAog^Vu$9 zoXRcnV3fe}Datohn%-!eC#Oo(p%9^0QMB8;v_qHCOA53(C^<#v@SFgU=M> z=rQ?Aq2P6Z1-BmY3zlTAWJ?TB56UzYE^DXmc0}m`CdyZD>_$VgpYM5=oZ0OSU03RV z$*5XiMmua_AAJ3#0ZY@disT4A;z*-iyFg|oMY-_+ITJ5Y`iU{mPOxT^c!_)TsKuIf zhQpbdm}@Q=oMcXkb&ThZutO1fgKTqMz&V*WF|{!gMbk*lAi8Tq0szTfw8F;gt==N7 zSY9WECI^4=3<2cQ`I-ZHKk85)IOFDF5UdGC&aSCMf zobzB2_RMyTQ%y^G;t`Uo*27aWB*n<{Zigp-!Q9%psR?-k?TNV8Ux5 zBDkExolKh?G5eoxEhUEg-9lDCwI1YR@uhp6%b66}Qu^i}k<(3kN(^Gl+CP%r<~72? 
z2hue|$S+ZCS-Qy1c{(fQE2hP|IPo^1sd;#Ggn)TkEc+$%O7UuwMK^P;bKa!#mi8k#&D>@Sl)P_55s~16-vML} zHp5b{HGve@XzebjWR-x#$lZL94P6?FDp>*pw@a6oUcQ~@ldESWGu+a6-1dvpyjX8C zsyKZb%%LS9{7xMu*@+V!N)UXYIHBdJ5`gl-1&5-x@eElhRO=Gs1#asF1+#`c9c5T; zuEbE}D{+FwfoBl*iBmy=l36shnJzqpxJ<6Pi0;PJCZ6)vXL5pY4mz-;RAG0njbd5h z20C1$kf%d}S(SMHCo1kX!H#j1xN^!)4MwFF0@|nGO#~f~}`FINXe0 zCE34(9jnkBTu(@N`Ffk;a-NQT0`8%-$_9ZcpgvhwDdw<2E`b%r*JNqX?!ZOb^ywXk z7`D1zNtoDm3Vk3l*tPf}kqS8urQ*|~bcdG0Ey1g2tTy9s6UhyDpFu@z_(Mwam@h~# z2!f=U_$0t~)IVtSnLAW7D+e;BqRK_2&CHPED?%@0T43BKoQ&{R9wnOA)e*C*(fSjUZq5^>OJq^asGj#Rf|p{x zq@b>-Iqrurltji#@AQay_K^8hbJjeDj3_D**l6nwmYGmI`%3co6r#QV03ka_w+6QO zjOm8&?KpFn`!!j{W(Zr>jLfE+z!I=6FSwnL_p~~_Finvemz%Mq>M=ZIojQN=^*Br6mvlyrfwKu&updScXG<=a)m0nx5 zg%ioa!_m$K_`2OM2+7_3U zR4Y12CLt_F5DJa>m3d&a!!b4*o7tdD+ria?$bHcw{3am=8PJ+sm)wS~@bNyNKg7Et zy1-cB%}u=qAqr@{c1!HZ&q;}*vh*We{{ROPm9RjzyYfOHWTW0@&tsWz8?Gjs$x&Eo zn_iffj>J_fIg~c9i-MrT$ttYhNa`;#u{5trvjpA@K@?>=krAt#7nhi(%H&#yV=U#E zd$wGIQtK=O{po~78!p0HmGCRHEJr(rEKr9nk>ls=%4&o%1<>_o5zcDzvh-5yvEY|l z%I6z}g&y#0nLA|#6R%im=Zxy2&oOmNhy#1M*?$ZR8VVNr~x7+>r{Ed$iqhRr%}Q3X{F56m?hc9~DOekO&Eh!$?N>dVDvP9h~~Xn!*S z(LSuStCjX44jDoOu;vvnM$FLiF<*I^XI}1TG~|hWlgYS;w{j{G-U$x$Orr8}7Uyw3 zVrEno9Y+yuU!|Iiy5)9YI?SJz3S3+SZetFcLk;c;o}%=5nN;;CzLABZJ+j@I?vobR zoVZ}LRo8?OlU3qZE-n*xgK<%!{t+UEr%T1U8_b$Bs`fsz_6l}`q0u#vn&9sX`2}$r zqUo0?xYoR;S$-VptAM&n+Uw#AOg)*I8nL*{a09fY_m|RIwvRnmfA&-vFI_U!Wp+5I zg5YO{PV(3Qwv#z|5oTQk8+A^Epl0uK=&~Rd=4d6X#zs_XWQC&FNEy{e;k*F9Y~1Mw zZNBkT4|H52jjT1m;+U<~_c2jtSANl@c>6J$w9n+2q^=jZ5xIxaf^b-F;m#zvlY5ck?EZ63bEZGb)R=hrIb35W1VhRV+1Y4 z&!w(b3u0MhsqoR71oB9*Y4d?`n-=F9n}esy9x7e#S~xlpxVkZO1ySZB<2Yn%1I5w= z#}jM;#4`T?aRs6)OQIZ3FSUXnZNvjjazVBjlW;Z}Veq)Ov6~0dkBE2 zd9k)oc4YKl%)rpl!&>hOdP*(JJQKRP%yPsni<*q&mSed>#%6;9(y6&*RYM19muY0) zqK4lBW#CL?F@9Tt9`&49aVk%d6wy|>Q(br$hgsCNul6h(tVgO?)Z{tQ>nYv=&-Uhl zN-MmT)tfd-KqMP+74Py)6U4dvqX%5g^_ME|U_sGON8V<`@i2E15ivVGHq5oY(JyPk z4K}zOw=+fG4(0({SDT(XNtOWGZoJEz!OFKXioBwsPzlPz3`T%)FkL?LF0WpUq2yre 
z;#%f*Mco}xx2rL^j05Lnu37^cffp6AIMg+zFDIPG>Q!g}XP2m|}No(YZ&^rDB271o#ug7RL$1K}r|J%va3R z&5vnh&e^&LQ*%jn0t*X2RHMKgvn7^wyphQ`Ic5iIq_0zm=sW5&G9!F9bu==VM4VdlDlAhL{hm(t6n zI8VU`>@?Ka6^Y@Ox$OA}&4wN#}}F*zeHq8fw(w=M$!=O{4B4_0Do z<<|?C0>1MKy|!VN)n$51D`|ntK3vA?#qt~MnG85wcY%184wnMU-Q}#GsKqa5#lX-X znqh9!fY@SJFXIZ@#tc&1&CDp>Ej_0!ZzFjdm^cD8}uU-koI-<0^&Cp(inExguc1%^+aA2jl}@k=bLM6WZzdFg%W(RW zVK*|mz6kpau=VVj3S^6`-jz_pxtO-p&@}XjTdy;teLV@mjECEhZZWsZhausXf#6H9+Km%&#X&%MWx+BC z_(0S3N;}w8MXIBCcHqTJutbDVvl6&mI|xH(!4(=)+Erhg;$`!VO*M6{B|3NAl}hB` zWq|?4pj6>Ap>Ki>aaS9c7@mr!%)VRU1?an6KGOFW#%9ti7V=Dn)-Qyn3I;QWw0;nE zDR?}AFoo`~oN&GKWCkMjPwE=;;O`I=+w3EaNnRpeLOQL=C~3ttP<>6{VEal#2THT} zpU{r1X6iFs$g@GzH;PCsOZFxpZfQk^n6qYT!@Vz<^ivy@0H;nPva5fBE`!c!TZ<0b z+)}H25Fh8t6`mg0`9xmhsm981vmK+Y|AfFSJq!!I+5MUO~uUjrSU49l2N6(S2- zTwbeI7I0K7QahmukCY1%C;=;y5l~Asm?6!;uGKc$6ACZ7WSi@>xN;gxrcQ<_8XmEL zRrZa27H;9<4wS$Jtnn2GzjFsA*5)avbgj_4#M>hDsJC39@mEqiE3$#3gW?cnPGOaC z)tZL<0<^5dmP3hz%HgYn6_6{m>sP@j8bR6yE2i*M^9&p0240Z@jmU#$j@?6%Xmo@H z+}&h?u9LcCi%@9Ms(?{W=4IO(cQGsnZ!2@=W~#MS*Xk?1W^=yj_ZDRpB1AS$ikfHx z-3CcX(k&WFVT>0k6b$>Y2~f(VN+tP&ZbRT9a!{d@-PE{JSw-0qJo%>~q$l;nDW z3h+4P6aixXVF_D~Rq>R>4iBs@;~kvZRN7%$7=*sbt0wdrZ6Mqn*xhN2iuhuZ`GGcZ z4M|C$C|1`OrEbvESisV?MMd$)STdR-xiEGj*gN-wTFJOwR$N3}gb<>~tU>Ey+->wY zmdmi~6uoT1N|x3Vk*j95Oq32VY8bqjnXd7;CL6Hn9nuut8;B~=;TKTGZVtYyhz!~= z>U@li`ROy->v_<{dFlkjpamXKtqU_rpb7?icZR|*x`j1JisO=`o`*Ei3Lu3Y&Pg)n5>Ucg2*qmc02B|r zu3yq9Ko0gyxNF1}hq@~eEDdShB1NN^I01NwvC8UJ&Z&D12~1YYC_)zMXd?wIE3}n* z5fNrcV~FPg1C1tXUUxxDVq>f~W2oAjGBV1c9`cm4oGmV@8op56tvMkB0mIf<*um0N zU?rH$5!wPEk}(GLmx*38YY@{VdPSU@>re++*D#LbQ?ModDhU{<{bpvKcsyFggl9;90+GmLGLrSc5MP-D`e)s}A+ zdRdw~h^@ynz9vzr{{S?!bWt&<4hSuUa~&=9g35)-Z3_6&rdc+tQ?#@jaB6wV&CMVj z@ytlVn#4h8EioJH>IZ(*xpQ9foI>kxWC;cE>X~Xs23gt?PGf+pZxE(vo2V-u7co;{ zI{xzP$UJ$ODtHTV!R%_u)BSAxQvPqmQ-0l z&2tpOc}68g>HL|nK~Zw_HOzcB6(~mErp%AREWFDiE#4Re!@k*8SwrGmr=C3;#oX(d zy<=cXVizj!}PL{jOX$8(j$l|8luqmMN)-~oj=3oMF!e34DTyhr;s-^1$SNO}> zJ10=ZWvxv_v!n?t;1T9BaT3xT<$IJYr-yh`Zo&#uQRY_)a0!~MId_6oX|fWvHiWwl 
zE$}bm5Un+6IcBZpl@%VYQH9{arp1n}p(?ZJ0&#jkDhlNk>gFH=R}*Ts1-4wV0*PUI z;ni`kF&W94ldcddj@``o8{8sSaHy2MVWs=$Mcyk5-XyNIHF{P{@a?6cWdgU|5rYX#12z7*6n;s(J3)IqaQ%Hqs904$xT#~{J zdQ$JC0T**>D@?|-^BcAttY;LLn4yl0aViiaOx!S9j4qpDj%|Z4IL8sObZF*27;_cb z`c5wiiDhR1rRB{A}oP!hgPy$A`< z68?Pdr3yHNixzll8S#cX+1-~>s z=7Q4tTAfVy?)v%UH_B&WRm3FwhoP+7GaRSTPOKJpc<+u0Ytl=k97FfO7Ww5ehO3R@ z=Q-Qmq_zyYrEvRS&(WVf3(Kv_Y%2cYK{Nt}4M{Vfng;-iBNIv<=EI`HrP*uPY0cCi z3@I4wY4~hM6H^7E7aTdTiLnP?fR8sTq07^a!lTN|dc{X4FxpYbG1Pd4HHY3%sh4Em zVfH^--~H`#!;W7xPnMj{0}}ILW1<&>&0~k;T+$|V3w%E>eG!FGwIJ!)=%!sY)vqOG z74E7c_@<6z0fBP2T+mbRZM&E+FW6xWG~P4g7CeYZ+IvPZuBAehe&FKgvH4SwCroSHQ#;Ux@3>0{x%=;)@kP;n5YPJBObmQS zYwJy(MkT-+&g9bEZH3Yuw$%MH4SssPdWraqyn)OJ^2Tnvp=He#rqan}ANB@}(VAP- zWZZ~GoEF+L#fhF@!;vGIbi()#EM$JrHBY%>J=IzhR^vS!Zo`phV6gSDVkFsNAt~$+ zNs!TE!XiCvl>Y#79F%ufv=v3dr7-xM%V7oZ%ix+gZ)X54pXS|(ogyj~sYHW&x|~IOe{e*J%405BWZ4c(FD%^{;7F^XqMV%D7VEw@q9JXJ1LqM;iw47dP(Z z+g!_X@s{4-0}vOtFlKp1Rg}{gD7xNubZ2`Kh=)(M<^n?kR|E%(t7-tOa4Xzq~&BCWffEw~klh%G{mBO585 ztjMsB@e&jOiL0x{hh%i;Q|B#;o}sy)8K>WM)ngciG%|nXfkkq}Zlc!Ul-l8{OhYjTSe8Jwb-@U`4JY8l9mX@B(|}Wcl7`=Nhj-(f5Er-GgyU+zU3EgjawZ${ zDC(zXf;%YE806+hv*tY95+mHgDLexYJB29Ac=cM6Z*xq&)<$ zS~s*qFcU)_s(Yxn;TLM zPg?S`0pz*3?@kJ2&!#idp2Bd4RFOc;&&rb-i`08chv)~8saU7}Ea zaX^9p+5iXv0|5a)5J;&HWFUmYTo~C1$RRE(9f=c_FPVBMo)(61RD&4M(9qTp@R0~5 ziGxSD?4M#o64;3`Nkv7o^m$?tplXn?VgBQ^G(m`j_+}umnQ;k+p(jHpF zCVdw9^_EYv5QKxu9xfQim+&?>=;s8_;ZJyXic;ZW5N5=fkv{Oc;GaUo`W7B2hIzx@ zELig!77`ORnl{8nBs9i`oFyR{26`vL;W$3V$Cz9~A`z@>vS(Q6+4w=7iEx6fyJTWK zevb_4>}+onC`)5`>|qlShegC57bTU7J>uf*LmKdTc)0T@r5oVzhl`Xvp*}k%DDbt8 zxX%Q!JQqUY;W$Ilaank!E9kh^6A>xUxS26THYb&zb2Bk8jg(ni;)v0qQRuXVXm5vI zmkVU!F*AAGT;6GoTxc<+gJSQ8c~6QU^Et+ac(<4>VvXRQ88{)X3yLlf#i69(M06)5 zmR=JSZg+*kSa_&JEK!}1nQoSqaa>Ta?GwWo@W-C;@dS%9%EhpqDuNh_B}BauaE#Tj$}iKQod{W2u%*<(?0XUt4wOUV1jOk$sMz}!I_oT=%oY}D(fD^f$c5G(1dRX`-e7I-n1tp+8BVkB9h%Ii zMu&(x$RLFW-p9OXNJ6C-B&jm6hBK`6UNF%_CfNzWCP~?GaYwrriY@j#ENWVd1FL-&)@J~DTvlkgeE-{ z-dDn36w1or))qX)G|;hOvS?yN!{-aae647L#3m7D#LuB5w#KM>7RfM(bZUGQ@8g9M 
z;YK`?AB0)iXgD}@OXTE_&eYsW#MSpVppO=We-G$U6<&%Si1iJ0p?)| z7r-H4;NM8T6YQQLlu(RC8ykake+OLFD3F5mF^`9W5Y;^mvcsoDzF{{?F+*+ee5O|g z!3a^Yh%>BQq{$T>MH3TvQ|*ns%)~3p7d6E!F$h6W#^}#DC3_H;-U!kXJ&pQEMQk1x zhuHZ0#?0uvJH*Aq4-MnD2uQ`Ju`P(u7{WhvMIsX}i>PRG4Z~rvvh(+YTsp%<%7Jt| zb}vI|u85V1rx|I9uvPy63id4_=uYTL7kGZN)=xsS0)~^w@7>ZwAG^i&&U^i0$Tc zNxp=Q!KK-Y^^_ECiMP@$#PbG+g;;3l+9cSM?DS6bHZvjnhQyJsP@LCde?q9vQHWC# zvL{1t6gJg1#I=O(*y~s%mWM``i81R2(fGpCY)|1yv9c5N8>A{kTNQl|*#yv)GE&ku zS{puvOJY({5&4H#M`uNEMD+fTa8Q|5+17-F;VIa&a+s~L0fvv!H_**&H_=P6B=#+> z3H=O4k7<6$yg#F9^nIAp7NNO=S8UOYqwLl^z3~X(7Y8xn5rX~=b zLQcgpj7o)wMkW~Bu|ge9i9X0-cVnTan4;M-mxRI?+$$2K*qg9yddVglD73aM?+2l6 z3Y{K=iPV|ak9K}w;_Pf`c0}8Fu>ZsW9}xfp0RRI50RaI400000000015fC6DF+ovb zaWMbd00;pB0RcY{WF}EF(Q6(|gg=8v!V`MOla$m?_&yZ6vI!8MyAY%p(KYSjxl*3- z^n}&I5R*Lr0D~IG%JRG}TeMoS5MGc&9!gG6COlkkC#a>xi?U)wdp&(oL^~$;YA%9u0L#Mj;3^HFp01#_OUI7e`}xGQzX^6HHH`6uBP;L_w+B z8xs)F`saBfQ6Y3vov|o3H}G9!JP^4`TtY4(u?vc{9A%G@JFbm<1Yrp>`fOgvGMzj-AKCo5}aW% zj9XDhPN8(mQA#l(^eVAyV|-rpF=9g(q7xROW6zk7(Kt073)Am-hCc)@nAr(p9(k;T z*`8$(hTw+g8T{d4Mlmg2jBlWv#x**~UWEE7+o3zgc!Lpjlzog$y51xql>Y!keb&X7 zCRjR5U+Dh;)3QH;W9&j|7Vx>~{m)2F6OtEfTTfp6=JB;e*TMFNqRW(HeGgG~iI;;J ze*`1vICfF7XTiS6Vn;&l4@LbGD;H(RMF5F%0yDNNut|r$cQIu~aC?HCIA(#)m|g z!F02+vAlKg?~i3JvW@qJzYj@_Na+>)BV6v8OS(aI zi-8>vi)Z_YCQ-5H(iR?#R5PLH(PBH!#8@|DNBo9kAfNZ86dbZ@R>PWT?iyJNF+ij zmr_I)e!+=SFMMxf5bdyNe;K_p0Zud#c^T)2kDx+Z)lH^IH4V;M>; z>^3rD8-?I`*4{mACSDX{6YNhvgAI=ne@lEOKC!aV2wQjW6~1KR+4Q9#g^V^Xm7dAb z??UL5ZIFb-(7%{`5SKfsdkjN93(Q`OeTgNFiMBNbT?#SK)~@in8pqP&wedy9D|kU- z#puRUu@mi|)F=2i^@#^ch9N2>E`&D#(?Bf0!Jgvc$HBZ9OiV;SybSZ$lFBpaNJ0?_ zAqaxtn6u1(@NNu$!7lh7?aqa4iJfO(NI^f+_X=JM7Ey6{^6+E0ct}DPw+LUT>^FEw zyDKX^JA&|9x;z)Z^KNYsG7S^y1g3a4T0bdM=t{1Hi_)RLtZH~SbIe5+A?5!7i_3|R z7}XjOv9q$9gH3cq6Kv@-@f|UFmm0)Ao^m0g<;qdfWfLz3BoY+SK8R+Sd-SW(vHT&m zh=1wV(h!Ga#4mBNI(S_0Tj=$RYC>fj>!PS9d?f9jc3#MSjyxMNvNn9^=&{ea;HAo1 zUh-`0U6+f*F$||fE`?>Je+T?O%xw&MQLgfZ_&|n`lqpJFu^+a5;Ln6VoR+@3Vt2y- z00@xl!KA$;VV>J_(FW!r^^Xw9&cq>~;F)_VPO|ZidO_HnLtTx*j^Cum?EW1OI}$1p 
z`=7$nw8>7-=txfxn#T4E8-jFbiklpIH3-+yd<<;4OZ$Y7+}g6@*j)Qaj?WFzlvu_4 zo)nu#QX~BsNSL0wXlU#{x?*r+dgsAkxfwqMnLF&XpAQg~$AvCY2&IQq`G{h+yBqEv z38A{^u8ALkTM%PY!Zt4YV@-61q7d{qVuNcGq8$k)AejYbL$ZwI)ncJvYdyijZt~`9(;L}yjF84_2HB; z_+qD|7E8Jb&vQ;!#9PalHc{5+F1m}&;#qm-302IoxCracb1fXR+!Ci7 zj~m~JE}Xm%&aVt5OxU&nnZPiSf-A1O~ah%3qbL}^tL@e%Swvbdv|X?f2OglYoH z#}>+7Pn0ujWtVMoOq%*&&SQ#q#HHK$fo~_7VzW`DUDR~qG@`z4Y@#Yy@|?w4TH`az zBnlUU$qIg_U3Ds+nD6bMp)osXAHf?$G)X0 zHx<3UqEUM#+LtQazGdGshZzq0iBC;SIK*EtQGRn1gUrU;?q8`(3b;@(2B?7Y9wK;p zVmq!TIt}v}S<`WH{$=<)%e}7K%vFt%txfzw>EAapyWPa@-*M60$d|cm?}#gJ+(p@l zo?FBMrK+Xi+Z8@A$Jkdeu>SxsTC0FFmevg&pNP9JskOp{7Q!<>a+SM1;ug6ntL89D ztq^N2f!~6B~XQKRh6%}Em|p*+kD0x96_@Amqu;|nNlCgefER}9G zb>?OMVsz9KbZ%I6^A|seuBub&qnyQVTE6B}`hZN=mL{;%tw#km5qkYfd;|NQ&7}a(!YI%O_r8N1Bb>mG>7}V^HEYF>c9Yp68QN zlQ%VW`+?2p^%hl!@i6J0C2OOiqnEUGboh#mIgO>PT*@|L(u>I^DxEV1GLNV*ioINP zXf)i`g})No{A>VUa(qV6F^BgVtROXfLad`%m$pI#(b2e}qr^5P)?>ajvV*kOx+<_c>*NETK90ImUb6ksw;T~93AuW(BKW#){(#Lxco7Qf^g zd+J+ZPcu$xD`ya!ym?~KbIcNV2h61RD>|0#{7NrN+$d$m@iQ>|@fNXl<|?(%nPfTY zSkK`vc)xSg9mhMXnKa&d{{Z89`>E*5Jlq8SX2&xjjQJCJC&;?PlY$>)dc{3Q2dN zi4lZS^vkfiHuVQbFS$i?gny*lv>Is-Fl92VG0DJi>Rj!k7c`I@lDyT~bcw60bsL3d zyD^|tb8z=lwx(5lDpBJ7%1kA78|p68RBsVr)ngED>_rCoV5N#bb52Q)1eb-I%zeIK zZqi=bto=gXI+T9m=YKvWydj4;Cw2IX$4+8s;$D2QQ+oFWtwEy)5Na#g_XKfu?tOZW zTYqzq?oyhlmyJ)jj+oU|O*PLDU*0taK@o_#& z+&0DC#&D&|f;fu?8{#B+SR;i?y`>>|i#V6r#0@4qW@Dyg7^{Z`EY%Xt<6OWVY3fqQ?Es0Kq7=#WFFiX9hqXLCmh(JNd zb0SNt;$Segq0qPS8jx|yER9z*^(~AQ^*2|!%q-EEoVNRcrx(s;CMo#m{{V|VVBG5$aEB?%%5pN`~5CvwR)Zzz+7@DFKdGiIj&zZIbFC4=;%63Wz$uJK+&X*CaM_Fsj zIlB6TwmXhWhSHK({u5IKHCbnuiDR3FV>iUV%v&|`N*qfc5Cf}}tVOmp#1wJ$Q-_s< z=^_-~?{5%$W5l7{qV3-`DR=vZf}Y^yNv`2!Ys5=XI6qTb(KsIjK%;ER z`9#oJZK&-?tiQXNE1OAe zu_>;hn*2hOFB4y}ELB&CwNAN#xX~2qQC)h9c;;biHSQ)E`8k+9a!M_0i-V__r9NUE z?hItlQ9)G;<Qu>x zJVSMO@yrzw^IXK@j>#T{$Vw*Gv_&*gW6Bq?IIgC0d*C6K7VW@-~h8tCYI3i7ESJU~IjmqZ` zHB(g=JzFVV*PDjh*HNWI9z5;=rHC)S;%d6;Hx+xwak1uagbnkbi0|1G5L?6C44=5& z(Ra37M3yh}2OGiSUU8e27PBqk>L~bS`6^y3_ZOq$$d`S?dDGOVtR5u@<2#i904%F} 
z<`7)sHByUfFtZLx=C3oNU^Omrh)1KuD_J*dXK~wuK6eC$CCxZx2`8mD-j7+%&chI!7qIN|$?tz}!YgI%mNof|eyWmNm*|+@+#0*}glA zLaa)y%2%0T_=cYw%v5evR|KYDrcAUOge8HfUtnF!l-OJy%Fb0%w-$!67mlMKpoyZ# zih)~`$51iazYw9}PT-dZs0|tIm3XXqin{#Dr^Fj?FSxj`%y>pXrx<{8%Fbf)pNUrM zhzwccWX`6|W7z{c@pF>?=6M+1U011izfDWXnwR34bh{p<{{VS|b$rd3Di*q(z)xID zL9_XujWI2c#K$IUaiNnEvUBP#Mm#+GnJ7%nLW|br?TWyBGmrNjcFPNPqKcE0CNOUF zox!;ozcHGqbg&5`X)`V$Zup9|tk+Q1aNI&?TyqGbxwJ%4P};!xmYWf1rNHnL6wN$F zidBUi#IhjOK}A$8ILyd}CLEJ0;I;~uW(IEy`4XHCM)b0rFW71ZYZq;Yx{QNv9IyWX zWyh;3%RXEi2BE^`GBsSyhPpOO5m2BSfiz$hRn#ulR6^q7ASz|2Y{3^ryHl8AoolJw zTfn1VSNHgvftgqF6sj7uF+?(?V*5}H!>E^{wTFaR2qO&g%qett+@a|gT7vJ1Yb@NQ zK0B7iU{l4e7_4G7mIm;Wi!VGyjbHaDzWS+EHHmeKZc_r=0bP-ISn4=xi<-G_WD%{9ny-ja zc1!{ZOKO~XAzi_|7aIt<@c4tNq$QMfA4}vg!6#O8l zO;!Vs1maqAz=JDV+zc$%=+sdCAa3%$W6g`|H)EM)_^m*{1977|9XDTUhC5ygeS9s28mt5vk+qPLx_aC@7i~ECFa>cA~ z+|OB#eaE9wII$tl{O)#hneu%_5~^Le#G=8Yi02n~h?zL<5K7OP<1JSHrGs!2L;8X= zf8-0Joa%Km7MD%=hG=<#Egi)$705u80@~C#h2a>Yi1CP-WpWLl2(%SywwDwG$>J4v zp3&K z=wHZ~0giI8e0h`*rc+|NdY8o-t7SbwvJpq<@<582$!Hi<73lkvn!UmQ06$Q}#;uLva zQA*KV)H(yRsssGOOBKCY{6HmbM}xPhhK0W0Qx+>JxUp&QU_tvw#Ip-k7h>o`V$Q97 zU-KQU88-{0tbwVoSg(7oVRH~!)v6elf`<>jrQI291cj2OdZ&~?)(F>ckNX^80I#WO zkJ}LI9dkCUFYzz#;>aee5bCaCSXNeNSW7$GnRm|M<8Rc#&SOpq!i)KVPNA-E)VIuX zFCWBP=MiSJ6&&VWebifb{{UuO+1xu4{{S(T#b^HjW>J166q%L4`wEYjFOnHJGcmyE z`NAxjQ^d zgx(g>_bN=h{{SbbD^VzF1Ja8xB|_zgF3V5%65Esp2ZmT(h}b=OgMo(yvcKbo1kS!V zjnHXmcPI--P8bDQQ(jX1?q{`k zrUR#+)D2V~_Ot3*$93%BUyA-sOAdb;s5;yDFO zpY~o&7=?zkKT%{_3YK-V?s4@006>gHc`7^j?p@&XxURwsQ(1Ey9Xvwjvkdf7)WaXd z^xP`duyercGaDJYv-*K%%OWVa9R8B&{$(zCBC+W$^5$*KU-d=!zGKdt{X=+oms93k zvX#*DEst)lVeTfBvF>pp;uh2Qf`F<7o_KQtf}rKYG5mLN^#F~ZP&xOv^o#5l(pca( zLSKs`S{O1yHfsR&A8_2=tPDV*ws~Jy;!#6Z@)MK#f$GY*SK#=>SNx^n_BG-I`YPRB zCkFec%Mgvpz4Ob`!PP<}X@Z(s7aUA!EP;R}S?RBb#I=%uEPTIza;7hH05kkcgF%(W z<1-e;O0|DdrC_U0%j47^`%!p+aZh|#2`f+>vyD*kL*XV?^+9qV?2fjmq7TrU;$;6Rg&=M?igB1i$J=##<1K< z5`MS!4wYU_DZ|agQsO>nr1Vt}m}bO_SQ_xQpdqZdHemXt`ju2(`%o9`BFZa>vx9E# 
zDtml~0qK7PT%ou!zANN`GRHA7kmm38F9X82Fa0IZF9P{hr|wo>ZC@VZ3OCa44F}CY zXE!-`d~@~c;k zkARR?)-u}PKPcIVISog~AB#TxTpf0!s;7K^`s zIQoU+-N;|r)z(9j_YQM_H*+r2zM-jv-UZp`C##p-lKEI)u@T<;h08%p29e|zuYuHI zFziqVP1BE11rAkikHH)9Y&=moVbfQ@h@L#mc}<4L_Q~!&RS9w4i}=@L5JdsR5wEpH zj4zd%!+%zIgE@}sdU>S!sa}eZ)wA!-Km(-Oc!l$Z5%D$PFLL=|-!Z`}(HTJ@%(pm$Rz}zITL&Uw6a1(oH2H=n#pd`cERJ5K4 zCKO6{g&4#*1+mz_bg#H^Y4WG{F0JmN{6EBYY0ajW?3Khi596{Bp}%B)C2*oH-b&aH zb4Vz@cod(OJCtQI?@8mrOWF?SNK|GBb)pZq+ zEMZGve@`)&sZ6_|P74e0vAQ6!gTQ$JeG#lZL5%9d;7%Qc+%oy1rY_x!}n70btbY>7z_7)F=ze{n=|=s8ti+{#r1&h>??JRiyg7me0v900ncqI;Qu zPHPT%mhS8*W2nk(QjPI4tVy<@>`$mpaoGoxGnswE@zkZy>KC1r%rU0%nA_D>VBZmG z%-eaJ+dey}3|-24{7t>{7ulE<%%$bK?mBbb_ZKFUqOGV^Vaqd=BfUb*;&DiVp0AkWAem`iV)EI+%5u*t#lD7QZ0XDQi1 z?!TzhitGy0;<$!zfV57K@z~7XjVGK%6#BVpXJR{ItQQ$xL%1~yL(A8I zl~X?^UPu5`DXPM&zx|f0xu#hMtl}cEyw_vvzYxJF$@7H>)VEv?Po;`+-ey+(GW^YM z87!zcZk{@o1h^F|`Bi;J=!F5%;Sd3YT#7$}R=$&+Bd^u{!U9j_T7IS0{Z~!r^r-bfcmDto$G}U- z0fX+CN8C#-Y~u8H1qPRsThIohHqJoP^C&Oz2>l-7W&==pIhg2%fVa=hM6N{@>c4JS zp{dn1Db2O~%i3bmc;a%Z8|ln=ZDtEh_cM67Roo7*;yNL`{%#BMdX=2U>gD@B;=Pr7 zin(`^8II)-;c+2es_PE zTCCWFr+ks@%H!myIQCSdSlt%B^`W3ZSraL zFM64-Y!&|h01$5n(X}%BaTpOORgk`VWs!kU;V?!vDfvNFF;b$o;}=(^pl38tlRNlL z^(mPKH_W;@r_5RaWx<_O?~l2H#s&fiKbhZ~#kPFlK0HTASk+RR_zUv@HkDk|UmIfU zG{vot=k+tSwyMf9))bYBJd0!Sk&0&yX?YK2{ld`14O2^BFYn?U6I$%0(et0OJ77(L z$X@|@g0-ZzU_M)~h!)e$HG4<`Cmms;VDyTo?<+VOYtFSLYZZsmNUr z(%R3=vh~JEM%a%->Z2$v_Y0M{V8P{-E-#2}9nN>S8EEr1yvqShTRP_F${{>?lqllj zpVYfQ{DLOfv^kY7S)(4M4J1X%%iO&IJGJ~ra*LQO$5m-83dh4Y_P#s}d zDAMu%+;On>c1hlgs zU_#1b5|g8nYd|*?8eaoa-DPxB0X^mq6JV$3p^aFB+KS7bW>v)G=2TB!~mFS|XEKpr^^q z1gO4O+%Q|WZc|SJ4Xdu96pR6iYGsjFgpY6&TGVBHKbVCN=22P1R5J`OUg2^)#Kg}> z+{IjKYE|MZe7(byzjCxL2Q9r+JfVY3OINepCqHn|Xm~!4`i9ar zQumy_!ddcEE3M*VIbbKUo`1NcX&lKq+4$}xc1i{FSD9cfMfB&KpGzpz-Xe;~1Dfsx zX~=3m7u+($VCbyUrhQxuUJ_B&SjS3#;^wXS0qPTF!BXCt&S^)dcOSx-fh`GpXrDf1 zg_o=qOUH^NauGAlAD~anV=WXWY$NP`Ai#}(-TwgWR;GjM&&hEspc^?mJvZ?TG;so! 
zH^Ni`P$z3Ue$25#a^21>em@8u%RHb~{A=Pc+8I@fe`oGmBo-xWhri}0y(lW*tub-0 zeUAYRp=vqH1)tPmQjybcf8L{NIB!2*B{b*H3ty%BhTx&{+_w9~$<0P9MyUNkhNsem z(08}riPmgDJ>Ka2`IW>&UpF5{%xN{zYP3cC+(5Z*X>+c$Uvl*!7eKGfthwMNCOPQw zGSJe(3iRHIMp0zKy~1O`tApRW+4<=YRi4*vj&bwsRX%Jg{ohoa0m z;h^{2sQ{;^{#>p{LizfP?D=DKkl4Mz#{}bl62fvp@$mrh10SdhJU|cPpdS&d0*4d) zVmrs?WhwZN!WUVqh)D1omRHW>yAO3MOGmO;ST4wV=<#fLtKn4yE&5$eENd`4S2u;g z`K`1HX?Yg&QJE;ZP_H;ORwp$BfnT^aENGtvU^V4fxH^BNvR(NQe!rPvEq$MZ`j*D4 zkb|!^0)1`}>Be5AKP#^(Mwa`8$_Yh0hpPQfv2;*pX;NzuZ5^g4rqX^{iC~~A?>xO! z?TmvzI!sBSLF?mHI&0cp;9uOS3VDXzTlxttva_Zafy%Sf+;1T5oRw+h3H(aPE7(6U z;jwPYj}0%$8I`bM4j*|vkk7mvO1JZ+pVZiy`Na?X_j3uUc*U;I#YBPI(UTXO)9PJ4 z7t)r|$+89DP}L=g-9JVrVxf+!{{UD(1RATz17q_5Xp!d&0@s3}MJh##3_Vjx_>1QB z!KiFklhn^O&r@i+RQQg7%RUl|c?2)GZz5cr#sU8TW(X0uzK(V7j^lL%<#xII_^5^9 zzG3IZ{vvLSLcr_feo1&`6M!zt>j%*s%Q0JWFPi#+FfOl)D?#TGRHzv-cz;rjky((q z2Sazn3X*N6&Ade8xL4_#!iUMwlrX)I%%TE~-9^rsXXA9ZS_bM^^!ne(F3xg7@ zWY_yJx|^_}dI+h?cZGLn617-832SPY)4DjBSN99nccXIAjYmt&TJ%MGNuVW?=(jOJ zmb^rk`GaO0RQ~`}%G-oPQD^LxE3$R2Y3sP*ZZC*#S9y#qzfp;{5Uh4?_QK!YQy8Fw z1^GmQF4pVfW-J`lu~@XcIw0reimxu9UIm1>{{VV`;hDQ%so)7;$znMRXoJ%2^@yvY z>E_#~>Qr79kZ%}&0qzYGL^JzXN9BxmLfX{u=RD-{5EM4|uLtfqC{7H8VErCa2#G~8En(n5i*3N$HhQnjK+#sJ z4U>jPu#K_~mp1b2BgDaCY-xR3SIZVqxhN0=s;_N9dR5dtjb9CLy;g{X{E@xv~0%D<4p&^#YX?Z%nsVxyKUIuQ9RBJBV@a8Rl4a z%JG=T?pDOhzc5yQmC>+T&6*#> z5;$+1mCAs<;H$F@3}lLri;X|+L9T-CU_o9L?5NpRH=&}uZ}kpTR^&^kujU^Bj4h|X zd8jZ(mx_$NV$1U>-n`Cf;2kf^8-^0v=+)$+b|vNzE+q=+z8aN3oNAzrxOHDK$A+$g z)IQN0pFC&??E08PaU}z%9*B+*A9mh;{{U#+(=eBtiLf4bzAMx^Z1PF^6SzH3Y0SEZ$fAgbp$w;EEzezKTJMywwuq zGA0<-I`=a|k-2w{TEP{xI8tZBKg4E{*cTe=(#82kn$2DSl++tj#Kamp3%QpPfb$%* z)j)XP7agK|&p4iJ>Qz{YUI#MJcs#GB$H&%_5Pg{EgjCI_w{Tl~NQh}(CM5q^GG z64K{k#B$~YV8VrV%5+rD_wFso#nDjROQ#RSr9k8Mcj_795s#kzMg$d>J`N@3>e^6v zs%`Zh1u{ZA!-asj5XHB(l#xklIFDdee3M3V;&Fw~>T269XW^jz$}#2PuO&E6#7Lu^ zPQ{|eFN-Su#c-;NQ0(^Zz9j;W%2N)(vE-KmZjkI}u%Rx>VU2=_GmIuU`%95G)HSRH z^5vU0DG5V}rE_5;Kb zNWN8cd^EV{<~n%h$Ut?SI=D4>nScT1p}$0Tf{ewZ9BW&IHbT!JAC^94B{ZiZxaH!Q 
zF@bZc7+Z!pA?31-9IFpjol23N6%^h_Yp+pRiVdMuEo}ZG)0>9IKI7t%oeH<<2cYXB zXrqjG;s(S4E)FZa;rWNHXy?!^eZD5vnNOn=iA23sI69&Yv8sp&IXPw}yGoQ}d=#KLiXkT{c?i@KTN{g($MNJzDFr;W;`i;USZEN)@r6MaOdxm)9 zL~0m6xj{SX95r6zVT>CVpIq@#=fOZ%h_0#(^9>_XDgCA+}tiAup4;q(F4pHSzjnh z7jUIM*<}|($K(CMcR?up9N&maPe^I&ZOhf^Sq`dE6ak1l+E=J!b0K+Vk_!67L+mlF zHc%>Rn>0n_TLIP#RR+(j+JWG9t^f#dj*NLgg;To#vv?~mhExz)qGgklc8kPs#JL0T z)^X&s9YSG2ShI%|*BF%g!fB~!_#01iAg0|=G6+5l@R4`W6~%Gvap5L=B+pAF^4*^h zmsXi5OTRt#LSZ7FGwop->R$lvC~Z9^6H)jX#Y%%(QAFn+CO=eN0-c_+?jVT(#KE@R z;{qd-t2Ugzcogw!dM~MBueW1OkBTeHE;B^y zL%F*K1r;}v6|;gca4zo(f7xjvYYu_sqn_~>Y_0Dw=tVjJptDCWsLD^a=xKnl-eS1~ zx|qAZq9`I@FZ^OzXzBCc1+ChM*#V~JTN-d5FuYkbAB%%)qOm#d0XmrlW}U<74FG{zgK)5L)#e>Gs)dn7*Yj1;jOoErr#A{IEbv zgPs2XQmrXjpA{*ysR|YR@0R28$@eCG|k$*PP^O;5EB|gJrZ@ZagbO~>QjeGsTGtL#h4W3$wERM)m zZHMiNW~@*zg01mPXbn?>2CHYnU0UsPUj*a2uHq@kX!$kpuLd7(*c#@|ha0tk4ij=h z*-fd-izyJ|2Dt$ZZQxx=rxbNCuBopvI3161qN@}Sn9JK0h_=Y%cVV@B*Thh)$OBgt z;KlBnTSZj7QQ#NlmOR_M5;FSuX84Jwsx0}QRKRku zQ-Q+PsTz71r#gXPJ3D7?SGqWb!N!YqeXAL$zbB$GV1JyiD8|s&N89$c{kfuzUF;^qkss)>;SQWTbs@ z`8Q8haOE^GDZmqh$?+%z^yTpWk&1vF0}gn9GPr~_6X|r8_Dp$3)V!M4S>gh^R@{da zdUI$-VQ;+%pI&Ti{(}ITNTm zj>zp(D5H_F>Je5z*L<5!k-}mww8O1kbQv4yt|fj5m1Bak(^PtZcaRj#Nq{1kv$hq; zx0Py*?uNMzr4=mo+PVqs&lfcE=ullJ%a^D`LR2R!)+~ZM5b1r?P;}gt+I|gKe^7$B znE_?v%X8GFMXq*%rWVTd)B`yTmR9z!_XjEy!ov#Vb#NpGj60h=)%li-=orT${(j}H zR5nZ~voGM#q0AN5{Yw7C1g{oL6Ol;p1c+YJ#!Kb=OjAayXRZ5>Jk`mc45cZ$e=Qc= z0X2R@KAgjaidH?Y;Ml2@r{np6&46w(^#Zn~6v%s&gbLtnscdA)QvS!~Xzr!0Uo8(J!iInKamq6rtvn=eTx^1B=ONPAL3Nkft7tQ}WR;h3UWh z!|pZ-N-oIG3a_e`b>0v@D&9GZO=Py$haTyI1fc9bgMRlF4L*@mORua$%^Tj*xcOzd z>TX9!Z@>y{B1h}5`5a{A^S z*1=%2N;84_jkrMD&?az{kkw!SybF-(^{bfG0bq}?=Ck>jQYIC^9%u%7QsndykaB6u zCzzy#sw z;dG-gSzTFuMZgNyEY>cZ;sC6x72H+eN%sF7H!*Zv4;;sl0JWoncnGBsd zhFsw}V&sJZhnJy@goJxyhToW(Nx@WpqBWyff$D&1@d3mw5wwBl+^blp1b7_~7qexb z`6|J%-27dU!NvzyD^uq)A{kkcc||sR{X;+`6lvobI{eK=sn4ExM&i6$od<8+6RqAd zaeagFDtH#fuVHAe$#_r3dtknC7@E5hzVdxdWD&UZuh-m268?orVW?#D2S3#zrh+9cQv&-udQ(r@1D+noV@u;Hd 
zTuX6cuZAm7Xl*#9zk7^oEvbViCMaePoRmc3GRfpxE8Hr(%*8CBg3xfKu0yM*PajcW ziJCdjg*fJzBKYKsig-IUq>L$FJ0;-pI(Ah5D+czO8J*y zn}b6DbhVSY@kdNwB^CIH*7X*f4lzo8B`T?5gn}H8G314F4N9JLD%?UMtm%W3Ja0lNs~)G~fhRndp*oA3lHbFw6V8q&Ct2ug^cVYqj%4#EO<%wbD z>BRLB6~oW>IdE~wn0&Z{io`wmRHST7%x7^8X@zg{D)iorKeg`2=P~}YdeIBhVI^p zZ?N3DvS!tI9z63HFEnOvwEUn;!m70bwP>%X8Ud!Oc%UC;K4Fksyn*b^VE92n^2NDo zI6{ls&!$o-Z%ggp#YA}xQMb%!g)~4G?j66* z4Z`=xG3gb0UR~6yEKuECIL>|`P3atR`9Z@_$&gxs>gE;!b{*4Ev^r}G=>F3ng(0^* z^z_DJF16dg^KtsILl!oa_46x*&1X#KOQU0$GLudABcc;l)HrHPo*Zw?NK@Sa)du#6 zK7J8XpAU{6HS~*Ttr3QUg>w_e`t3)B=8MOuF3B!>Wrrv7@cgD(x4L?xo%tT3+b*g4#WCEGdYwtddj@aX&jXvF>WuoGUODp7(Dmcx|Z!^!5Y&_Z>+)ejs{DeRUQeNFDEhW z3-b;ss30%9vG)}PthL;JM z6ynG3IPa$VFIgz$gSr)s+3{<<(K9=RNDP)@uAAIfmw>f^*uU(}X1xLG-FuxumS^Xa zh(&eXftzBEisC#aWyth+$rXoGOJTqqYaZe!exTZ8R7d7Q-b?#}1ItO|AoI+zp$i%E zhsQSsD5s$Dg$}nXZ6ruH##nY{mpmf-$I=qEUQl(e@xGv8ZdKE_%m}v@n6Sn={mK;b zqeBO*t|C}?yKA6okobe);i^XNS2Q~AC&?9$07(Ee}n}4V)g=nP<@7^GA`$k#!2$JoM%%D;>R>e;^PsF1%qh~9ZW^pRxuqjS? z9A{G?nA~A=kYOY4c^7}zQErl8aA0fwP$pnkM|9^Q!a?Ja!Z+Z0rW+a*<8JTIslj1{ z+4gP`l;68t@en7Y%hTGNHkUHVZYl0dqVhPuL0ZD?3J? 
z(ky?*PBMT&1Ym^_EYA*%Y?&XYf)j~k(n>KfNIKXNbsn)2~uQQwpKiDII3j?dchKjl-Jk_ZdK$bSw1U=hw zw`*hn07*+NB|L-le3Bf&RAHvVI`wXdsBP>*ohaM4EXwIN6jp3=)d-CRq&+jQuQ1WD zRb!S|)gT#aW>l_IDSLRqh$~2HExp%Q2{lfIYrOr>ItbPIX^t}<5gV_d%Pp{FE4`NE z7hdBOC7K2TU_0yP08lwI=33hhipBAAtSZ{m4qJuoK@m7(@a1<{jX2ctXO5x0M$;Rn zf3JzY#zz%@JpD!0Xs3LK-_$t-B{!w8uB#t1zk;-ihcx9?{Gk8>3no*CO^?j8LwL6f zc)u+xYSIo)#rnuiY5ABFHru~}NcTGv%lqSXZp%}?f3snB1i{hS`ocW%;>(s3F8 z0H!;%{55R90gCE7r5%RMz&I#bL8srST`<^YZtETp(zwu}j&u2myx`rl3{0nOn>SpQ zJApE6+pvAk7E*5iQx0J8GK>Mq!NFw^qODJyLNcu$BIveYUY0h|5HSYnlL4FUzVm*dPm z0^k@vC^%VM$LgBo+4zO3D-C>#KP*3W68YT-+dUAwG# zmXeGJ%(qw8F)(MU`!MsIyi`}>+`(y($Z6HS9r1GHGd88p1m{(IsL_ctos^D;{i&;Z zj|Qe0>}W2;y1u1*CYGR_uki(52Rlps<+-tD9t`Wj4Y6xPShIM?$mVggsZh2v72d==AA&jWu_ zy=!_PM~6-zm`g@DadCNJF4o<8#q$Q-4g(v$CVyCt{hdMTst%JXJ;s|gGSAkM0NB+r z>f+5NV5Qe_&`oZ|b@$rhco>Wq?hY5~Te65qEM7fbk%&#Vt-x=hATd`R;A(%0AtjVT z3V<8JC@>!2Rv{mEl*TeO=}mT>~3W zoI8lP?m>rGvwu)B9{Ik`sty!rtMFd^!;r)#exg(sKwDo?e$~5|s1X4LRGI6{BeJW` zk)(dlxnBX&x@Dl_du0!Tf~+8KoFAD;Be^!&FMMCTL0V&O>@jtM@CXE~Rq^cf^$ikH zB=HsEixXI2=DiVVz-=ng4=c)4h{{XNJDpiOCh=?1sBL@hD z0zljc+Z3HmGTbq@w5~#q3!3Nd<|B+9y~l)wXo|c%hE`J2v^Fv57cGjur-l)gt-S?) zWlCLbtIBcfEV)pk(c0vz>VfVixiAThD?f^Z21vLqUI)SVDh)ttzc-8g-ePW$d_d9b z^bpGnDA{>zerj9|+5X7;i&{!s9|?bnc-S%DWy8d(tg6%G;roW9C+$p?@O-mlf=xcJx=R`{=m>VU?qD@ zkQIkC2lBq+pC=+8QoI{^H15^@sZmLo0K0n@H3$$nK=(^AJpv2p;<(&uUtl}|`2Ec? 
zDWmk2h33Jk<-e%vC`itAUCYuG(K6LMT?6w0C}>&=HS%Yv#<+?ku=H_y#3JUJzaBNc zwh;*{LT%RX!tz|q{vF1v+j0RdM$$@fGJVZ#ex+UEkaTv&DxwaObOpm5p-%dy20Q#6fvPOPsa=f5vJ0lqUB@$z8wD~`@F4hdI z_zZtpj}KUwBarW$SshVLp!lN!&_lH~nm%9KQG;7eJ|(m+S5c16rA>bLDTE%It_VWz>ZWBmhfvW{t2&>RDg_$i|xc4H_D69+Rr4>2*%40_BQ0k&@W-^L-=&3qHG3?wRQY_0zcR*MdH!o?S{{WSQPr*%W6S}Ma0I^VF7T5|y*S$nj zcrd$FIuzh#{_zcckz0A8p@~{zn({ih^~`)2EhBj6)UT50ebB(X(U#^e*$olBu705h zjR8)wzV#|_ax>_7h!pk|^<`hEr_J&e_fnSKsU_p6m0c)49Whwprn{sQoWU&NM!hY5 zWh}8)W_WHi(|{FCcPSP)EB=rB3Q`)fweu@I1`phPN2lE43ZWUhZn08a*CZT18 zv?U>IG|eLRf-4Qo#jB_rKINUuDR;O;=H=o0W>z68)6Cpv6^zUkNcM9sUoHuImBup= zjsQDjS1NM2CP$wjB4I~^b^ickwq{QL4>00W#JTdY#e@rJ06&%|MGML{(u@~hQNnxg zqd;=Ys4%#IqyXAOI5VBS9eMQ%a9Ff5tgRmAQl*@9-GvS7I`LJ)cWk!hyd?cif&mVuAXF1`dZ ze<~A(zL-FM>i~SXJzOIzeJlGP(=rCMiG^*lK^=J}Dzx(XAW{LqPYIu~g#;c2@W6x_ zmTJ7jIJ8Z?7G7zq;wskc$CX;h_c0@S+YKl)UU`WO5o1`l!&>tkO*23RUA%7|?cx$< z9c?OMa916sv{n^O!=K%eF&vZ=34*bA#4F_DMhk$?jS$UXtaAK-U4CUQ8YLSnc%j4d z1eZx}GyU81Go3-WOxpA2A07^{tzV<_4i19kj8*;1Uxo&S&3=(BN0$0Ah;qhI6XzVm zfqF-Jb*iuX2BEd$%XsCN=2g}J>&(zWFkSs_9ESqqU1X@J(+k3K=N0C21#xB)PW+veX+(Mj{|;Z0UbN#11`g&uGOK=WbXgXG=8~ zLDmX8@hKFW!K%P`^Zx*4?FmW&>{0>iJxt&RFN~j45FS(A2Lr?o2xBbrrw67Zf+T0F z0rD{U=P3Je`WbPeWd-$YT9sJs|wyy&Hl_wnZNyRSb&MRXpGHmh?-y zdkimnVzPiy5$^N1%%;HXUue_pmGWDWzrl%4#%Xn8hP?g4wzhB^cpsW2b4aP=$`KE2 zkZFr+Va6fY#Jb68a;K=G&G&{r;2;*NGsDMG*{~u{llY2h28E}G7x{`CbZCkBz4sKfdKTUm zH)0E0citlPCoKqiexuiv4Ohs2JV$md@G*0u13bAJ*UCRlV8<{erF3ly1jBG$WsFzNP7{sH!FI+&eXGc#Y5;E?7+rvu#S+RBX!xh$4;=J4>EXpz_PteJE#?xg}%U<`ISd{k(IHW^We z2N`%}T>uf{mDmuj@d*G$WUkyLjtX4BsSd#&3uW23Yd}B($7**?Rv}9d$qes}QR3j2 z8Y@XJ!({9UZ8&FlGgnzk*YPVSF+)W@-2GmNb+W+jE3ihj9@h~n7~7{Vu`PwoeD=BB zSIi>J4}ssijVPwgYh^DL$HWB=u_x1{T&H-5B;deN%F0@49;Lcmt^It?e?j=Rh*(q$jUl0NW;R8SzQWv$BY{GY6s70k)m-9HKtVgA?yN%hQ zlhFV)ITxFN=34y13YvKy*zYZ>3x1%>pB=z6b|{T^9M^{ugXilF6laJnf-y^XM%(t2 z4wWedbL4e1fW4}6E;Qmhk&cs;3f^lzw0Q(#;U)luym#Lzg& zekw!Y#6J= zAvTR_t@4SexT#Z1^IH~V>}d4d1{y5V>d1FTyaGJh8%2GZ%> z?lnZMQpRw<)a7YkO62ehXpVbgpbPX|?>dXzwz8ZG)xmFfrDb5Yvc68DLlYd1@?!U& 
zb3epHE0C+M3a{Kc%z3|RJZ|@x^E{cg-ZH%F-s4g3NRJM>zie>Cs-L2o{J@sg6%PaO zjNuT{T_N*w{6fBPfUg-X4#fe$6`s#*{-!uqk?|T8gs%ib=iyI(iFR#854M5$xV024 zDPw}M_?M_&h4=;o;3efV%OAo1sF10cGQ5jlOdXgtt`4#R{{UcP2INcgOJ#1B^5>CD z#v#>f8Q*=!Oj#60UK1*c{rn)`q<2n)$G`pY6l`ozpllmZ~1cHqs7?Q~~2GuZ)ty`N#Z}q7_md-G8oPl*)YzZa3y7 zP@@QY1C?_f(73v*tAzHi4_@8o2JGhp;D#d8qjx9YH!3W&O6kJZc&i}f(yJ&wb5lbK zRKMyz-wlIMZ;Ut%}-LV{NvfDgGc2f4NDmaW>3Lzp0k<05gz$O8)=}UP}I8zxEsbOz#qfFCk44wl%t+S8*^Kh)_(OTt-sj9D6`8LO{T$%-Qz?j305w88s*;eNSjp^WqD9lY-(LVlt{(rOo_> zO*sx?1ZB#DO|5x2frbaR?l2XsIE!Fwj-b$>=BUfVPzpX$^Mhq?#3M$qE#P#%d5$C& z)0OL-x{J?Sd5fL_cs)zn1AuhxzZWw-);8F-{>#CrL`jDV2B>2RdA!HB8X|#o(?afz z32b>=F+dk(g{NJnsE>5Oxk+KjhF{_hZK;W%>eydbQ4Vt%2ZSy>MtA~J-DCnU!A~;d zn3?SGfV_Ja7}-<>^;hgg%|fgx+Vp%{x}4?Q7QMxiru`ohktM)~bU(}nS%ZC5TK!9C z3)$H+PlP4n<&{0#PpFQ~O-td-y;Pt)q+0nDJniufPKC5V);N|*9*}?El!`idzhZu7 zpRj0X`BnYI2gcf?{S`DL8bPN`7gC&|pwhl}_>bL9g*j(C&zSY5JVjbilqR%F?*%zu zQMl1``X2J%xDLQFqoL>>24#U!L@Aze2L5#(1V}@=Fh^*>~N6e#RTi{#X_HCN&x032Z(|CICj z_9hhgNamuck6BlyDM7-6_;2wI40a;?!8CV48I=L~%xd=vn49YFcLN+( zQB#ad#!$zcKBfNv;V>-5qHn9_;g({CkH;|`pSUfu62bP-pci=QErsa@W@T#Gk_{_C z0{y_t+q=LvyJg`;kT@6_Q-f|@*ptsHLV%}Mhqs_EB98z(&C>Jq?xHd>=6$SImB$k4$(xjV0?lm%oUo6ZWF#?lW#di@g5;2h5<^x@+<+ zP=L4<2HMcEcrEiV9lRL)u+M>vh+3Y#ApEf)7MCCK#A#l4{gVOYQ~)P_LEaBjF|+OtVS(OEHRPsfEmz0F&3VjefY>^F91ogfz&OW0 z&%|=f5#z+BQ}RpnWi$wOQOBs3a_vvnPsD5~aYlUq07$`Vvw!*sT8h+amRWd6I2`$d zQvM^OqAWe{Qzzm$Z{jSs{Y&xqfUtFPwU-~5h(!e+;(5EA5Pl~OKwnpakGp`;fS$ct;os<&y^pj<<8@wHY+Gg6G_-ze z%xcqFrrX1(ABJITIb>8R1EBN}XV_gW^_KofvB(AV`LKQog#d>rYdR4)yY(!DG$0;P z@3-a}VqMC(QT&8oFR|$_;y*J)`(FKJZ_<26ay&7PD=fdy(Xd{5+NL(ZTbrpzChr*b zU0`wvd`z-39mLgM7lHVtG$K7S3D6OO58lT%ot}ZhnJ<%%<;4txWe`Z>Konq^I(=R5YrH9=aVD!O;(JqSx+HC_5RDiyaK5$^k1 zmtatGjDA#c=35j_0Zl&;SC+1~AB$oj7L_*aaF;`U%T^j>1=9HGn&yBm-G6C(E(5QK zo;=E66w6Pc*zyoUkQTXb$)5=YVr)qM27c~WhEttT(>(pgB?km~GM-C?0%ZYf@%)gy>FoLJ_1B0H zM-6%$`!VVN`gx4a$EADI6XX3`t)&wE37i&lf7^U9dzhq8$GK5|!c{JfT+W 
z>L>pIlGcfz@t@Rj$IO2uXP611`I+$CaF=+CVN2zRtxIAy_W15qcL^-E9G#MJ#Gdx1sYN{*6w*7SvZHWsm>gY`I)HQ#@bHcYL!QZ(w*VAz zYmWdFOBGbJ<=;2aA0bLKRh857tMjQ?6;)xHt)2lD)sXf0O#iR?o(Q3ct1WD3;4z77z{KZ~=Q@JLGh*I={#b;!-yl`PzizsTEHRaLubbCR zLryJ_?xDA!gL=zX^WyEN9mYOWxeeZSttrLXQOLfkK7*j}a}AWo9kJnV0l4IVD9~5> zh{r&hTc2>tt04sv_occ6==d+WLCs9St(_TlRR)bKn=db%ADE)rUq*!b3o3vXd?*6^ zR1`6gI{yG=`++yiuL|Awv*>{Zv#iWl-GxUCp@Hi-e9~vcRpsOv;J33d6_!smJ38d^ zFqqgAg}6^(olBN&n}paoAKszCopQ5&1BC7cVuJi$<33E&5dQ$G7+UmX>D;8oU<;zq z#*CRG76-9jFYheCV4!K_HxP{6?FpL`reHi^L0zaZdKF0cG=0 ze$Y&N%I_ARO6lJYLFuq)gO<{+F7@IT2 z06Ypg=HME7ur}Ln@pCzH3RG}kX7g@ znllnxYM?Y4G}_DZEp{6@_Wb&_(%`#eDv^>~1GHOtFR54povn0Uex;oCpjNbc?!7Le zvv_b7ZSjXnaV^jZO0V0Qno6Dn$U!z#fwgPj?gqHiQ^|jtzoHsgn)1^g@Qz?rJ3k%i zfgn0iZLhD)w1p_qvVPC0t)~9~R!WAc#v&t?_2l7-Yd}XZ9UxgfSv*#l^IgnE-WooS zkN%Z~j~1EWc;IpI1)vMqKk+|N4>?d5hjpeM>6GJ&+M{(abfz4M4M~jhKBg1Z*@=7< z5Wn_c&k++*AHnfD!v6q|F#(ElOqqoy5OOAy+*2p>J~jLtd7ETP!8)b^0L&~G)5Fh9 z$-2JbrsXaPr_48NR|Bi=Bc)u_M0}uM_GMf~xXa2Cvvmo>75altjYX|q^(@azzw^D#5rf#^2K=+cDw2v?DqqcL~>|(>le&JTLG+U zmFJ4v)Kjr=?YUt`o6(}8&4cs>^U3PtW>yLw0p;*z@eL61#M`R*b-`Ki+g&)P!|2-1 zal~gH2B>go1s?Yjv%OM0gO%m<)+L~>#^v(=0Kk1@j##&B{G4a4R0m+ez!?u^2bJWN zQUx2H0&?f%V6?MARmAmldrqS8D(tpX(ibUQXI)+g)xR+?<%;>tN*L-_<%&ZdC^9ZC zE}Ax%hi{vf4ziV~w!P|zEZHF~xQFP#AXnDzRKH+L0=B?5Db&}#xl3iGLZ)+C! 
z<_yBom$Y`{Jf3`I)NeuHR^9<9G4M4Dw&3Q6o!n1}+gW(?8rxZwj-6A{S@2*9T7i=@ z#XRrGy+rq3(gT3)P#wM^{E<9058&UV%=zy*pK7b;60I@rAFdwaCaQ()>R~!xFE;eL*P2U($FGF3c4$epk5K+(UKUiM5 zpsgn(+zPd8P*V^03hMI@3Rvb)ob!dyJhh%`ZB=Jm_owj_t6s__iTy`bMUn*X`a5PB zQ_FuoVLr$a=gc#c5Wls@x8_@L@2I<$JHZ$N!J~<5c6yZ!A~%``O4=$Ig^mIbH5ZaK z4cEdzpf>U`;JhyWV2eO1kMnWXGo@+XzNY~qJ)k#@L%W5>G#1EPpBD|4Ww0AS`cIJ^ z`W{Cs>>c`w0F3zzyKAfU72tC#7fa~Se8!-dkP*E;L2r*}Wns5Dum|G&+Vd&NU9nql z(65z-JB^vEPs>1m$c$Jwv0Y!3xU{KuzCyWM0l0vg*CXsK9*>9)bV~NOzxNS0F|>0_ zRsOw3g-blVZ?$KuONVd`+NwQZ@#Ke*Co0;>$!?3cm>+_!3AH%Qz+D;8DHo&unPa!^ zT5>efSf>MpbX9RPXiyjeMYpA1m}pl09u z95Dr}JQjY*aT4~p!>@NR;iLy5l>Y#66qRsZ2l`XtfzLQ?!K?E}*ul+@f&(8SzPv^U zyfBj|qbI4H1;D~hS?<5rC?>ZBQBTi2r&F2&%&eZ=k=chdFDu!MdcVZ8nL-?>eLu{6 z+!oXzy^!)Yo5fj`-+||s8G@+0ZXjYKekF^~sHbsC*pB?n z0K;f;!TkKoU{0?eP)NVd58`c9{J*%yZ1vZuGWAB{9b?3x*W71kxSn?_Uo1{dyvK(e zOyA~Pcs^lQUvOFd#JsiUWUHvq0-*p?J;0V{n6p`WArqSp33xQ(=CN-}EFh-?BR$8S ze7P@iD8LS}%^mJPL^4z=cCm#v*7p`uh zbEAT&PXf-2MuS&IP*e3o^#B=MMJ;psVV8{+8!10)=!bs*ZOYkC7)Og(w}5-bQ(iT? 
zX#*Tqtr2(yI(@RH&Ge%KlAl6fa}#lp6<(IRS~I2FbY|9D+S`sGURsG@6}A=^-e?_T z1a_4da#2q$-$jJ4Fxt*@m*EdYR;xkt1+RenjW%2(AU|d3cN|qxuEBowPe_b93a~-r zL*fW&6jlXsc*r~m*Mi5*U-6$TbG;AEe*qXXPX|kv=lhfqt$GFzx4 zQ`VpkeosQ52m*L6+~Jn~h#HC_yLA=Uz4K4<^0Sx_g5hlkVbw!NZfVbfh;y;3J-iFI zhds=?hNGnL`O~sB)$t9OrSCBtwQ-LPpC$bfFy1?1pD^9?Fq_WhRekdabrwd@edn^v4-J@s-=idPWLSu-%D{?^MXnX!7r#N!M5V$yI?#j z>_1R`tsdoeEo4vSI5Y$*SYO%%Y|zaEZ98URAy_4Oa;s|U4DmV5C7#ngdyPae1cPNC z3Z^{76k)&_kn@_(zNT5QadQh>W8)mKvT`0lQLe)^-LM5W2uga_Djr~S>xB^%er@1~ zyugbeIPkoDvb!q$Y6G8|{KeGT75@MOGTz{$K(OP@pJYYZ5qiaXUxLRyqz-GK_o$?R zxH19Sh4xPqC`+A{O}FHruxLzWumX-$3#1~RAl5nYDmV*9mIhwz?LHllFmC1C4UM8z zyxe#M4q!XFQLI>dQvA*W2*#5S<@m>nq9=Cl@bAPH0pQAE;-Hi}cQLw;*dJ&C%7=8H zGhT?eprU6^->aH92zm{e-=o-)HL>-zA zf7m`fLA$%F>MNfDax8VD@fi2e&$s&^j>H&8i9(Ajwc_!&v4x0jS`F~~{s`Q5a^I2i zNBH6(u=X6ed9P4lLm{;H)%hcNTEPq42>C4>J<18t(Cx|h_^4qs+3>)$lA(*3OPC+b zDo`F{W>Di%=2J4dl`_I~<_qpN)EqkXF-)^^)7-|XlgwU|q~$WvUXCA{D0(JrxOaI&laDUh=zr;wHJ&sZ9GH_s%?M;DcM)8 z5K63;*4NOxd|bp4HZTpU^kymGZmukWeR?aIa50?iZtycj+w%_BFza(tX7OTr= z$_nbEqAdk5dFG4yjf7qGv;7FUtP_N;wMxcX&TPxan3kw+*Gss)d8chAUWa zZ8+iKq7_d)eJEFAmvS+mr`!yUo^g}zpTx$36+>?&e4n^c8@wJ_eJr86 z>E_yfEL*lZcn^~49L^GYBo2HeJA(m$nGQ*7<>%90*WZ z_?HzFQBl?WM&+nq1};}K65SlIgtW@pz|c$f1H?NaJZ<+2TJnO_{ZGh9~owOrch zn{0hs-Yqk#IeSV<%g7cd{MSPLfb4)jMzlYqVXivxw z<1vxYs%X8~fDf&P_8&_y{{U>IpggtZWRSRkHF*9tYU^k7C-0l7B#cO zujG6{a_%8AZ!7mGWH|QU$;tl!B~OuBye2X5KTIfOEp$5gV$&QLwzB%9Lw5{5&!6gI zV6u1$d<1-vQWQ|;_*FanAo^nv`tiw=JG{qAZS5IAUmM0L3d*HL`e?pMtjx(rgK(Q0*RA)A^b z5C&PJ5R$R@h{S@}SV*LFd?R4qP?q|#PO!{hpk zDbp!i%NPZ%US9nZH@nVW*s8hy23Ra^3w=s9?^3MJk-vDqcLUTZ3_1MDXCD!>@iD>< zwQJ02Lz;-UuU805+{*h72r6I?5XuJ<;VolqWrvt+9KPkSlI!!W^!tr~jXE6^Y4AFk zv`XrM@LCCAf$MT5hXn)khzZq8s4%O!qW)%r2Pad9x(1D2^$k>dgKb-s*X9%O4dA!- z3imnfmqu0iU*MJpxpN27>GuXfC7VSBns7c*lB~e6i*4MmS(2K;(9aEDFv+0-IQ`^m({tu~N6&C0h`6s%S%AuX{KOcq5 zAR^JK{ktC!2WZ0HnpORuOs8EF9#Fo~Ph?SDtxaame)0I2(#O&5{P-Eq0K2}M`i*yU zc1p+b_I?=?NG6zV;?1D2t~Pg?Uz%UJuakjrw3Ouk2xiX$+)0bZ#2HHZo%+g^V5KB1G|r)BL)r?ZnDNc*aFNUOLF|U 
zuB+PxDzH3_3xmn7S(Y~8yR4nnA52sQ2H|D$8ufDl!5gb7`F@}itOJ+%O}+`8G#inf zzN|j$%)MPR{Nn_>6%9}#j^qqy4WHBPNKi3we729$bV zq??YFT$+RUVyx3tde8lUr$wvNL4P@p>hZIWW1BLL7 zFkL$%&KFO~Ux?w#bZx*dn-9dQ^mNJHz6_ocFz8?I4yNTP7s)4evTIL}b zOrI_pa#dR3jW_^8pq96|@T-XMnf%6h1A^v-SMxAEbpmewpe|F~0jQ0{7XJVULXBl* zea{@C4**{z{{SXyhALiGc`8#HzM?L(5}Z6tc=?)IMeS|wQDZ9=GMqQH(Jwnr;_6p+ zmx33~S5o!0Hf`wf6oa&Nk<0vj!f2E&c5~zKMfPqgI3VdmJ7RIFC`JR3vs5wg6tIix}QOZ6_B#m`~-B8=yI8o$&g2sb^KHVezSNm)$!gql?q1ppWn zfmcH@#V_tOz>T*B62-KR7@EhpqAJr7$hH*s1d7nAW>o{-F@B=ozxxE*9Q6^+6DOHd zZxDj4M;5hlwc;(kwR03%(pg^8lo5SwNWDv44DhE6<_B-Nt=+6E8UA76@d$rXy}e=t z*O{|0T&iW1VF#e{&DJi{GJvjK0klI?R^ia`5e`7w=)fW^ z)n#Y{{lo-t7MYNhp0d}(dVm$r#`Aj3`NU1cvOL!hfImJWm43X-eFBYu zmRpc$cb72dSU*;x{e}F?iD7AWUyEV!E-RAnY5e>W;%+vyU7Xf;_=W=jfF^Bem@vNSsCRj{!gfi zh`IAy^NYkZmH{g>R=r|YS;KQk`=(aWLR^Je@;_iek7awsy7>mYjNTI2Im`Glk<60; z8vTmC;cEc+g!x$i0Lkwqaio3M;Vy{dJ4F@wuMbkH2HBS&_egs_WekW58&x)WrF=7i z0Ew;F@P1$ykkf9i-Ab z#dy8Kzv=+zamw;Ryx-JZIrkSnNnP9bEfMiHk>EptiNaA~(>(d{9Ddhi4Fg-D(5g2H*C4g{ zFFIOz9%f$d0^UwbZ4QLh175RZ_y{{FrQL#`R4@skW8^t}{{T-hMY)#Adl$Cf5g3j% ztGg864sG!bk6YfT7$-sQE}c320^46tZ(nQD>w`$}}6P}U0vzw{{TQC%d)BB z(ep5L!I6Gul1d==FX>l%--F&h;#YwQfs=2b#C5g>y$;V;ljQDPh=YRFYt7{D zrEm^wgGC2{#|*iEW%3+{r8g3;&HynVc6`NDR89f|@oDuHFept^VIQ1(iDem204eU< zJ<7Aic`~1y`-o}gHA3H}DD8uJ9RB|RiCBRSRUeX>m}<9kgim$(MRK(G1O2H>Xz7?E z8%AN-91K)@V&YDOxR->4^c(y@QX6gDF@owOt8O2h!Hk2Xm;*uB5Si#0Q!dD}58E{4?M-3?Ogy@pm8?Kk*_#{G zz4cHhmzbXQ%oloQSc$E3CeP4QR3}X3jgR<{)NsahZLeitT~OfQ>)!v;(!dc$6TcJAb#RSwMs5 zfcmWFt(d#Kmo4+V@vh*Ywu_~=()|+a@k1i;`OIwlEgLAy(3j020q7dy# zm$iTpgV+N3sDWGK77z#P?Fo_~u0qRWoF6ckYGM>C8}pj@SZE5X4ikSAzC6I*wzW9D z9}rnIYh0bbQu+G%3_rvLC<_jW{C*H#f;a1Y<^i1Ht5NwE7VZ^r=F7=KP}1uf)5s24xs>XsfwQE?Q;yFism({)XeqAa};n* z!sUKH?5M5XN|Vh(zYK`XFm4wGBK*n`X1w(T@xNYT=$fe1Bw#V6aolmY@g2>0mM+p9 zpfTl?nS{RmK~#)F_>7sr3=qpAfgh44V8ucD8I5JHll3h%2T(6KAekRjN_=p3$;Y|m{wcvmy zjYEqA=8w2`aJ+R+7yE>xBI~YWOWJ^D+_m**)C55bIY#%*{Ess%HtATQ4-JROGqfwZ z$au27)21TT7Vur3t^10zHZy1*Ppk9u5P5Z17_;)$UZ(+jQ$)wT{6e`_%%cAQZSf5- 
zOdPCH&3>a=`60@K=W95wKyprpSz_?MefxA=fWGvE~s34*O9C{Z-=#0O9X z^_bs9rJwJqmtF#SF5zyvBm=U`#dKrO`^f>=Jo|)=9~^M3ODn4npYa`$URT@-x*lfq zJ2M7J0S{=CpH5)=iuV*9FHzfHGS!9vbq%PL!LtP^Z4$V+bTu8X?J&Mp7v#z+1zi6C zC`r{_YOBYIc{XSKsYQya=3Ii``Hph1w1_U(8i;em6~)V1sqw&b^%#g31wlS$ild*v zCGC~va>e@dk1l(HvxQu1lE7QvU#q!{o57m@(1&dxs>?HA`o!e$u6c=+R;L2QJ82-jl4K1Q0|}m@oS?eWS!) zJX^wF`uw3Kb5QP)Y%-OoEyc7JLX%u%DEyDi%h8B2&OQ-@3=Z&Q{BvHDYNq1wdV0J0 zjZ&~J*QfdTVJuYy(a+d^OXqgTAS=Q-doSr4XxsBFEAq?9$_w^nhlK~8it#1?IT(9e{reQiu^_`<%6Kd@hriD29FTSV3bqb z6xssRp!4D`40x>n0K_v+R^^T3%plwPm_#5P;4$4L{KGLs2iY8EGccKz9=5|oi0lV> z4Ks&6JXOD#Zdq0cz|`Z*De>wdLN=E9QPl9|g`saS0$LJ-&LEg(D(NQ2^Zx+A1@P{t zphTLnL7iBSCz!KfEwG^Zu!>oGPne;YZJA2nn_qC!S^#-&=ZpjtrLsVUjw;3ZxLg%?6=QC^zogxv zLrY=zL-IwjX-S8g5 zz#C0?CohWnSP>{7a0&h^<{-fqrrO8uz9sQJxw824m*U7sl|&kxo*7e+@dYqVbVeCD zu6{hkjbX%Ytqc9IZNvrPID|DBRp|J&mg@;C%Y$EQ_Q6S6kTNe{nC;IPx&05+0jaQp z@?+{dM3~1t{{RyOwF8k_Bh|9m=ZD!0`*IWA=k+RNT-x=-HPAjLYK-$PmNv~Mh9w@l zA%Wgr`It7=sD_XQTlmg?5G1R5>VODziU8Ln5 z13ko8?hcVn4bk(NZA^mdpbk8)2ipL}m~W7moct zW&5kB`22|HDqdk@`+ngCdHP!Ze{!l_+}PGf$Ui5TQ>bF2FvnKw-0j_V6~**p$Q{cZ zYi=?0c6_n(cWoom#P?P1qPAEyaH#(Pe-W_hwk&DJIjgbUrA{?xDuam=~q(^&Kkq9U))u zFyIV|J|pd>Wfcrd(7%`B7A-k%?TDHM7Jm>x%U71C&uAD}3_$m#FH5*>$eM!d%>?)0 zYA+}ZcMUeOJZ@OLztsEQIU%8&Xv8&ee^V|M=P%r6f;RsE-sXeVzGd$Xbsd4EQeF_|zdoFF~B)sA;0nymp1lG+YQ~_z${)bCqZ1!NdJaG%J+_ zH;Sov0*h$S9fvZFNk~$*)Dt?l0rvjciNK){4chIwdKIFx^tJ%IY`EL%Qlsu<^Ts7h z<<-C(db)yQ`-01ef9(KkmM=P)@=@jr0Ri?$b+2~NM@!L=YTwjd0pQ z9^kc6qn#Q1s*daB`bk-oU}zdAt0flb>N!vcH=*EM+LsZY`SiJVvR1-@+(z~Fm^kGc=DsY8NaG{3N}IWNQ- zKnOXiE#kxTFC+)AE+3cfBC!*z-Ie?)XlHetg1#Xina9CB6oxv{1Gb~?--tH=5CbTG zGd@@`Hi?}70CDGZi*`847cWYq`$qw!SReDaTWSFR0B#4hjXYyffG5cwXL)dEfdYkz zJfQ+}1GrMaW-A+XW?dHP{`!{HrBL2>Z1B#*q4OzM7X28P)tW8!6bCpyhxh*gC&!rD z7Ps_&aIw4o0tp@S^)o5Tb@!v2lhm&2#rN#@paMQMdiXSXc`j zohimp^JJ{!fsG11(0si|xGrpMf$;9VN06drUr(IxkBBGE1hfa(J;126w{PTAo*+Xu z)95kcCjKzsKqsN2+Q-n6e<279PIBa`=F{;XCo(z|Kh42Z zzB4Axo-l?ng|!Yc2mZTQ^pz~`_&!1g~_)+>`Fh&LiSMX8x#)?%V#j*M@e8H`)LHSB{p+;IAL2mYEJGhO?TP5yo8_@}E 
zx2-|T;_4S}p%N`@mVs%C6jlAOgPPbpaV*MTXNY1Js0BORsAUU+?}*nqQ}X`+$zwps z{*USkHQ4#z^DwEkez)M2EvLfBtNdwoIZ0#e1joaKG5h_b;EPOx-!?lbE_=Kwz!<(=TcPTSV) zy*-25x$dY4$cUNZcq5A82J;$gwZgCA(yiOz*B{e6<45*roN^aMBsBxmV7_NJmnimEc;7*!oC(AOWEY6_n5C%(#%vq3clh} z+jwRGYq@YaxbsC}SryCNtD}?v{`r_UX1H3GwkRC%$7-^(Pd(6_807Y1*^Zx+O4|j^?fVlN?tviE%sA0dEx~=st4o`_o zbjEUpX?X4x1IFVQvD^jD;};D;*Zc#P6jHekdwYKnI@pXKHOuuZYuYQc;#X31Y88H7 zBGwL@snhMg{YsCC5vuW^==vp)rIZlULj$5YL#lX~RzJ5f)>4j*j3=VJ61QxdpLueO z&|_8HHXY@G1?lFlGe-3_MP`iv->M;>2V+LE^k>_0~_d_>n6oU@nnhs~h8 zAq-pqegQtB4t#j1_IT!=ARZ!!yf%(;S5STgb`{g}d9lf>rlo*{NP&tysJ z!|oabPEG=m@e61djoS)!S^P|^hPnkJ@mKgi5G3T}a$E$UN1aT<^1#AMGPSF=V6iIi zYaSqU;eV(tipLDb3sK|58614-USp0Y$Z#OEWsPXI6$UefZc?d-7b=BsZ~GR@#G>9x z{BtrBvr^W%{tW)%(RsuD`IYUSpq%l8JpTZ$ZQ8{i+xiHXDuFAgHVxB!U-26n0~HR{ zL-~{?B;y*;R!4aGf-GhVxM)pZC}U6_dB=&inl;NE5?qjan#bI)+%>< z>140P#(`-HGWUn>KQIfAuHXo5hJnFP9!i@!ZNt&!{6|8K1+*}^uCAQHZAiGVW8M5D zK$Pb`H{sSIjYv^OypMBf&J~;IiEo54$q)q0j6r7ZA-}lRxqFm6-pHuM`Htn+afZwH z19xe|s4k8E1jtXw!4-w0OMoIc>MMZmrfGQ;2lX5s)CIh?-5%9e{Xtu`t6ig&?gs{d zUH5~R{c!-gJcEmVYeZ#NdM@l7JYNy}$cubM<>lv$M%K5$PvUBW%h7NVs4}2dj{L6N z!s(ztvK~Mc^YI64(DI`r<*fB^8e5~}dvx`O5!jZVXcay;{$k+RXNCOQASHa4thf2% z1VY97r|taU=VUq-7t*D7T}8ow?bxvUBO=hAVw2|4fbJ6UW<=@P{{SGXGq{}r>ht2< z$5t4e7WLW63|-lnQ)(rcFJs2ooKT=D)W?x7oi$$489SE{Z*JJSAUW(SiLj=BSwTeUiZ zy?sBJCNcq$ICBSFt8|B_d86D6IW>zkc?a8XQsi1)viL@8&!+}55e;0h=EzJ zcbRqpmjN1z4*9I&%@$iX%&~0`EuMCygbrMG9|fpUNUfG-?Y;vAcidZnD6@ zPJ6Tc#9MWFiLuJ(oy54>EesZHH^oYQ0UAebn=>EaV#{s03US`56gOGMkm0Jo!w)i1 z1n01sK1+fYmB15Nbjp?iL+#T>w^s!Xy)2@@w7-x38bR{najdYu;x0i(=0-Sc5H2n9 z%*X7BO8V!i$%mG1AE=g8@nmQ$H)dIJ7<}ei(zpxx>IH^Q8(u;jyKuzx?cj=PQA6Sz zsdDcEhm37?pNXzPYXbcFV+3)!w2(D+!n6zQ6KzwHdV zDn_Ag6{-dJzTjintm0W$23kL;a1euonf14Tm~NTd1wLu;Ut|@ADa%m&3*rv$ZB4WhrfC0-HHRi4pdlGYKF6dyERiEHDoCSMwUW?XGefHz9!v;Jnq@q9Wyf4JvU zwDZLJwH}*=ZWCjX4gGJKT8cqhMz>O-ZMD$(9D0}}tV-?%5N!Z|h6vpbpyl=EIQ1ik zk22CT=DB<`HL9kZVkP=tE}%;Q@0esO9Bqo&!YJV7n`!0d9k%jH7zXZNU$_Wx{{S+) 
zw%+B%PMkbThwWqVIEqip9h7up5FC^04Xe>UN1qVS23m&P>p;@h0_Wg~t5I#e@2+C8 zLWW_}d70VMP4619)iF6s0X1dC^QbU_%SA6M+eg%_MIprCwT}!HJica7M}YbGcx8wUd#L!696Wx0qf804Cz)fw?1A?8i2X84G!ti_1{K-uS2bf^BHJf#iIwX1^fkmRQ5>`6yt)o9D^;V^!uYVSY_%H^?4TbP zb;>gK0QbnhSd{&jr0DDF{{Rs766aSxUU7l&US znZb_umUH=mMI*EV0H|nia=|LIZ$HvmP+*KN z4-R2;^iEa2;LtD3LYCQHrY0Qd>I}^146B^5O*0zcFS$UzCENQ;3L9{Elpmr~o2^15 z8{iClyZZj(z7=7;OvvOw$X61$?`%rH?-6pXp3Li6YjGt&TV4ehoA{RfCDuy3=Ix=6 z#IL9kWcU>2vv`ZJqh9T>yr9V)HeEwT{&EQbC4eAWp^Cn1JAf$W57_G!Qw3>Xqc*?l zmacf?L&qrY3w3nfbLyc=0*S+UZ_hY7jtiYEx$jfgh{Tn2ehl;f0AY1~1F&$MyKd%d z6{qm(1RH4GQFU)PoTque`AZWR+yf?c2Q{0S0F>Fm0>ekd$pXo1r*IcWW+)xSd_E=+ zYVcwxs0R?h0KAp3_bW#1+WU+F7Xxke%plx{VO~@9q91@vEygrT=`D$Bb6LvwC?V7= zhsP=OMnV;(ZR^j(9%@tM9UL9f@e<4#28+q-7v>)EVWY2Ymx~VIeXOaH%pQimyv^8V z@GZXA@Dkk8`e+`6Un47l90_S}_Wyo!~$sScvp_k zSvrT%-ZdThk4{;#3+?jpQ8i+E)Y4lDxC^Z2AnWc@COCs3%&^?QW*R|nq!0$XWgUah za4m~Zq8y_=%#N;DA@igx>eMPK7*-q}A+HhMLTGWhYkw0nNFBje<_cDVSu1X0t8@^g z1vI4%RcKL<;s`A>Wc8i8I z3DnqG1;_`SWQgkP64Tj$qp(4BS#T~X`psWZYl@ly;e0xxTM())(D_E|JeZzl4P$6l zvUE^<%rqI-cps71xQ;pq04koZTtNb??FQ3(zD-MO`T#wJGoweC!Cgjbf17I0P`WWy z3vqgZGvKfl%a)$#fEX6abQ##-Xzl!SG7(f@Zf-os2d1+1{laE!Cks~m zh>wyw+8$v-;8>vh#rQcYQHH5(pvmd_GvPs3r`IBSgun*FpnFW`k(Ff3N`P`aU)S7L zE-S1f!gB-<#me$^@eRbRc=j+AX1MM4Yp=%-D)fBzF~X^eah!h zD19)f<>ibP(Vm2+D-QO=Hq07@vrh9D9d&Hx6hf-GuaWL;@|*|UAz#d-n5TpG!(6xK zDPvhB556F@jnGrb@0j*#(zZUhfFiK8t5S*%dV^X|682MF!$%BUR!T9?xP%Sf1ukmI z&NZGRsk92$(9y2>gKq{G)umeJA5hzJQ2R1i&0l0xyO^rGpnSIu_yuBo_wgzVtqV#U zUKQQu;X{B9m$AQd7%ZQ90%h|BfPG%>U5n0Zix{(&?m_9@`MH`&d6$J(YmDOfj=Q!b zR#Z6j)H=%pJpFM^*XNa{z^MgJVx7j1TuL3@p`GXzaKJelBYv z4K0dAyVX;pf>{Fsv}*yg&|2}R(GzOj(F-0fC|4b3g@|8hW$)+9T{SZRH9Yi3r8XBTHk+m!afWu;0HAmT^)_j z##7b9;#@2%yZ-=@XsTVb7;t=I1W-`K@%Ec;5~h%5)mEVJmp~cqYblE{#*CY+&+M&=mvj7 z^DsU<>B_qu>J6Bp%Hr?I(8VzjO@n^+`-4;vRkUD(L*CDE0;3!(RcAlUpgH#(uMI;( zSon@$Zr#c&vCIqDD^i$ih!`P%*o6iPb1hpf+YPQ^d2~cAo=cb@2U3Oaxp?VOEnh_P zb2>1Bx=YYtL3$DWmA(T!%dLhE2B`l4u+zA<%;j-H!J&i>@K7?!tSjqbODdz4#6@S2aFsNq82Y 
zwytpXF7MHhuk}RfK%y5_dHfQRO)^M%XxB%mL^i+`;f+=}IliWp!~&=2)XEppJFqEd z+n=dP!(g0oLao*9wTmU<2b`GnE%O03vsd0ad4+|AZU;9`ej_!F;j#4R#`6}w;vlYV zKaL}NP4kN^vu{j8Y}dR?a4(#|Dt|C^DCFuoYnAK29SA~Fd4GwF+xqfic z66-A(KueMY*DfCJo@F2l1Senv;+CQ50MHBD&VJw{sxzaZ{{W_CU8S~-x!R5{{-f(a zrscr)@HR)U7(atoUzwa<6aw^5JKfIRXNsD6d>CSRktPFnte#J(IT?_O@Y&rjizAre zE+dxfKJihH6^oQ*cB(lI)ma(bu;#e~73BEt6;ulY7}lTQ%U zUVKFuaOQ4t#X#Ab-eKhMwq=ofAEG54pB%HnQ3z_FFb;?iBg41(jtE_62Q+a9aPtbs zFNiXS5^VAQTv>_jz%Re*C%$uSc<6*6O_%oh@%IxTY_U5Vmaql6b~VA&y&6|Qa6Py& z?jx(rF=A1j@cqi;iQ2P_d`432SXyu2aSCOjTCG2S&vJ)w5x~Rv%xMY+p0QuVw3Z`n zUocn21%|EF{Xd9S)6dcvKiQ6Z)^=IV>G{lN1|#N?FU=9%D8b0!dUNUqHeF-sPWKqE zt5KM~&vL{Z1~0NS7l83CUwlU;$E3?i^DLTR^SA;{^A@=-8(Xp@$(4Ki<{F2+L=U}c zSedWoD9K})e$CC}GTIvR2}T^AE)h$`tzh>!Rq*_`_?0YKM@pG`7>F7rCWck3KbWSt z)%G3q$~EtEb`7pmWZ;7ZmsjBYLleY$UlbbGD>2jqQ8}{^4}v-EsxbIG#R|bJd^q^% zhT2Oy@>HOs4cBpvCU?=|-PInVzJYuY;W3`OmU@A@uSnWDCMo(&u`&#+78q*k1NPj7 zE4xjn9cze|DZ@lO{{SgaO~_q_R=!O^r~oF$cSZhnx2b?rvaB&?>zioAqi(8^eg+o9 zTJ&`cKCA`pu#~9V0uD|UZ~jHff&zpGC5Drm=6MMc#~1D^jo4~C^XqfS8QKKpgdwxU zThRID1@P)IyGW+3k9|XC-}frub6tC|ixGCF2gwp-}DUKbb)i z&#*b=qkTf97nvW2&ZRh^9uM|k+zQi=*cWfz#}wiq^>Iy@wVsvxj+JCvvQ~r|duKnn zX6Be*B_y`2upd8t{0CqiTz2R1F$KQvh>>CIbVD zhw8w&#=X2!U;? zAah@IN5PGVTP;*?KaW!gRE@(wmgP{J8V7BExRq3ImmC~>+)!JpER!U@8TSJ*MXcJq zTWS4GuyhKx+6r~=5%wA!a5cX#h{QD-qO4V$-dH`L%p=`;$E?H?SVsH?!_4viEr3-V z^DRS3n;@;1;&4>?f+Gxi+QyPC+BF80JX8yG_spqj%o8DVVxsN14p}cS80@Q2Sf5x! 
zwSpL%ExE_zkzsE=#Dc7-(DZXJAsgf@0IQ{vR?zRvqjUoYYs(rt<`By68kBFiVE8jV zl&m{D*An!AO%6`ioyzfJEXoXYU{lw>-7?vzn-szIQ#e7fjOX}=lUr`;XzgbEONJ20 zFBHT!1<-v}{{S(X&7M)zss@f${=cbxv~t;K>wV-m5H3M<;pzd!9;V@t;w2o%?j^i= zlw(rPGK17uE;)+=K{tX6)zg`D{!A=pSO)sI*JIys)v#3?z5+3Vv>PU_qRPf?zDC}$ z;%mVGCyEoP+GR%FvEabayudova~rDDT1@9@XTBn$MWK62fA$#?D;cTLuJbk}1)}Z0gy~LUL4IC*5)g1G;(#A z39vMuq43{K6>52OYd&)ej+Ti#&L885H8^44PZ?bMfPsz4EY4oP-ea$4MEKq-&ij=r zhADUDf9*swGGI~j%~a7v$#S_e9sziTr+g;hI9^~&>XS$&2ff8M-HaE*8$3Zv6l*uI zTQhf}DQ@UNR1J=8{XuDH1A~Q~`hv0;(^<;ATWEgSWoYjP>y^d*uf@ji=3_wNHwLZQ z3Y!Kc6}2!yyi7%!>NRGrfcH3QruFC!9gp+=Eyty&)m__t}%MAsoujmS$}_sHAPyx@%f&pi{2p&^Lg@c zL(tS8P$&>#Jj}ZdP5mP&$lx4w`bakdY5ZRirM6|}Cr~V_p5{7Z&fqujh&F$jOqX!C zH`Hk94nK)WyKxaOk_y~N_|#4{IGXVpzJ13rO6L^J}>SRi&p(jFnvUvtF#N^w~cORyo+@6XI`S>>fJwUUHF9Q zOxZ_o`w*rn!fp0(Ux{_(;2RZ2{(6bct&^(w>6cx!q_&mhAh7AYp>oHlE zkO&^@FkmGTA)j}f>!LDW5MpR7azP4rBtq@Fw>flZfUm+9_{FPcRZ+H$Yn5uc_bn7) zWD{*WbB?Wtc3N{g3)}wyd_;XXC|l2q$8$2~?(s(a7tejh&Hn%o>IV1&;jt{{WLNAC7(-gS2Lk z1VEs?okw2c!9Qb2c8ZN0+1t%c5T$N4v-?X@2DH$4_=t0OCRb%gGMDVHQ27*T1=x9} zVw#Aw{*M``RkM`qIbbCG{{T?3A~y{j{s_66-!$O){ve}Jhi01p0ChOoDeHgxF+5u%vv`z*889P3j%mjW&POmHJdk8r6mz$xa6s$dbxW)iVswmxYs zFl%l8d5dx6xdt8l#X3oV@Lh|=vHH5AP^_WL+YZvM$>eUb_H{8r1VeYz5FI(AkU47L zD2A7wFZkyyMU_k2KbI8?h*nH}7yiu0dKc|^&GQ)P+yMhtY$A{)m6B-j^Pf=6<*$Ww zb9$?kx28~C4>igD;v6c4Jib0+{LuoAJb$wnB3DTmA>UVlYkNGk4I<~g`^_RzfzMH9v4`tTc;W=-7{%|cBoQr2JcTlfZlo(PwmwDeV1 z-NbMWu}tWq@tjK_;eRzKdcWTiLCT?9oarULLj!5Zo;NR*>2?h#m}?o(H=kG3tWz^p zeUMO~2aX^SYVw*d7nmzL_cD9${{V@9_Jx>BfFb5u{{U1dvi54bx`WwVVOQ=kdvUpI^AvD%-M;kM1JG+BKe)hOxc>0NufW7bU-T z?^RCYmaPP%kU4sk8i|{qH~Oe0x&j65!pFEY1%r8mmyG;N`T;gIxOZN!Ux+)h6tv_x zBhinJP8niRCgd6TG(_`OlJayM{-WAKiVy8_@d8x%GL_9Sjlru;uE)3h#L1*r_cuY& zxowF|k;2!6)Ga1#FjkJ9V#>lnX~n729xA1>{1w3Ee0K`=oRvR-{y#Fn4Y1N(Bl%_V zE{DNfo%wgvv^#;T^7uF%L#D-nw%3)vS@Am%y*VZe8~?!0t2~ znfWR2UdT2o25EGqcp_l|a4!Ty6~AvhKy7jTM>z+~dZZj4I*UlZNWjT;%(nd;!kWjJ z1N_2)E8e_RwNK3Zh=(fGOi(H2V+fU2Sa_ThYCPCDxp)XT-7VO9^8O}84i}#+6VMs! 
z;5`1Lkl-CYJiQT)w)YfmcClcTR(~n$`h!d$sh&sIKLlwZvx*(Oy-t5ALq0FdbuC;X zyK|T38wS3B#dy@B#b&Q%pB6H;X}o&o1+Im+&E^9JgOkw9Dptr}49{A}r+;zrf&s|z zd`9(e1DD14^EuEcojx5f8er}J01-M`=b3SiGEDc916;vH@_|iX7dCnGE>mOB?SSnd zt2gd2k*k>;wHXAlKQdY@T5ZyxwSlVE`r;^L!EN_k5ynT-uegj`a7wr%a$H%X#ZE;D zXCHjT%Q?JX79HzwN zQ*c~Sqhop08WBR}9>1m+_N;-?rv9c_L2mk1ET|PJC|-vW#W#FF@d+3h!+(=q+yY1e zLHU0&k~MtL!mNLFr0Or=Bg7a^| z9kZ6u^+zw_Q<1vxc+|0EEtZ}5_r#{jDVt1y;CHOTq-KqN3C1!UM#@5lG0muRyvFEf z0e-8?gK&gn-$orTbg#s}wY99agJyhEJ;O_I#J(zv{bD6D*NBDe@WTV@Is?Df4$GHjub8md@d#&*;2JZ9v8GyORactKR(H5Raj7C!3xxv#XDa%R!EaW6 zSj|vmH*4Si!GaTo3;jZAG<-_+C4jC7_Mf?t zD?Ng)y1YK+=0}qBo}dMZV{jXJi(bFv51i?Q2?bKw{Y17P#~z^)0#m=>{$+rwmTs5_ev1Q3=Hf7X8d(D)1{!9% zcx43;)1iOA666GML(;#QWod5QJ7Ee-Yq!3olvdyr#MR2tK7}Tx@Ukop?jLlBsy_@_-w9^s+Unsm?o?>1l@fpRb!R}V<3BEdD zfdRn7#POS>FEXYY&Yh|bTk9oL)Fg{0hlxv}MA7a%(5d>zxZ@P9fcd>k%b;-9$!W?G z!{ZU4*t)TInP&o|FpqNUV;Lu#+bwR)t4r4~+F`|6WJt0C4SZ!IVMG%^_v#60^qC6y zjKIv(P8n4a&I1MAa}THxs_qV)(fj2!LD2$;e9sq~x0v$L*=k*2i4(Ta8rTha<}J1D zw?2=;4*L&3q!K4^t5%%40*%e-PMbxCIxIwSDcEDKcoa-E}cQ*+pKKA92l=6GwtqLM)pX z9N$1!U=d-5ogV-+Mm;&@uFviUx`!)om3y4*7JJpd#5>FB5Gt)zhs5@x@uDRZD<{qW z01?&+b6s|R2%_|F>FQWDt83~E;wIeYd5QRWm}c89m>K|b=im1!zz?x;MtB=v5WESC z!&8Kf^@;=GDxavbsAWO}&Ou|X!NTGS6=gYn5kcXWtL7l=Zqu5?G;09BLT_FHk?|g* z7&0ZGFcvA?V3DnuuNJeRFuJ-MfC?730;~oBbv!Z*YuvuD-lRVBmiYJ-5fW4GaKgTY_Ai zo>;39n=2{@kXGgQ1V>|ab$E=Q1Wz2R&JSS?roa?p1`-qCb>bovqS5BDITbR!y06?e zg2XLHDB9d0t(QX0nP#St%9)TSOF>-|#5hp9n%mzY^~*Rl@_?+>|ZP}i7gQS-&p zv(VqTUa!XsjZmuoCM<+>V-tMN*tqD`!Tx1Fipxv;;svdtk3TAYZldV7*!vkT5$^TYp!!tcjZa8c!$n57CAIWl)AX!C<>@RHM)q< ze5IJ%$*A~<09>uj=||>J?DZ0^S}Im$Pnls`cZLhVj!;hU`hkivTvQyvQeyqEOW!P~ z(gO`uz4RlFbySb-Qp$Hr)yLGcATQ=C3|XSp^A!iS-P80&wgAC|JiNsPP;C`&K^0Be zNqdR{o!@W$Qk8H$ssKh5E9YOBMGzD}3V1%@{cdCiJ3)1x}MWPJHlsWhwAfy^lpGucN4$-uQCIrimQL0E> zHK>Nrdl#5QmTzFMJj2oQu0N@6sJA0 zut`ph#t2smr%%J$|zw%Qx`M5o7O$ATcJET@xOs^;MNu-0t6|hC zmuGbjJ5y;I+A61brIS_WTtNc%ha)T{UX#U#!ry|;2$sB0;V4T z`JP~rY7tt1?1Sy{B0h*}lvXqt&!Wfnz|F^256h`iLrT26u8%UeiM38VeMD!#tna}M 
zg0mUrDav900Dfh!0hv=}yhF!yEHafcL$^THz(Vk#@B^5yH2fw+l_BOC@?x1q3Ud>1 z+a(2^j60w>9nJh&g`g?`g}YjXW(N>38seFNDrJ?)5)}>E#BC%hDpaRsfi+MSV8C|9 zO97sI$^`BzU^TyRH)d|lCJj3mk>(RvLM&h6W2yU&*F&3NNUY1E&Ju{{XQczBJd9Q2ptkYt+C7fbiq{jt5Fby^sBtuDVC9HlaiZ z6}`&tTCCu#f8!M}6gL6A`syZrNDeUdD5Z=VO7Tp8+2Q>V@FzKjA5fST@4oA(?5SNr zFNXgB+l!}K$nGK1oFZJcZSexIfiMj(_Y$%TJJa9D=;CH%^HkFE-rHputa@+f8;Giy zt6~IKnRuaG#Ee6hUwrZ9hs7RFpjlQ&XhyfXnX6~>8WyUVV}pAd zLL@j*`WCj9vG71nS!cO9Z@9?MT-I*8=0A&;KofX-Z}j}Z4qT->4L8nOmu#i5rE zd5Q+V5um?7X^OR)7*BrXSs5z49O4tqYMGOoJnCA(P%nTor>L>xOv!XRj)wr+^DlUY zt$}HL8hTHxL=7_pQ|Hd*x_|*tpq~>1H4A!f-v0oowSPq-vC@js4MGjTQt_Uw>lG6Q z)ve)Lt>=mJD@iXwad7BYp_-jHbQUmKRvUJdpciW^+$4489L3c^yi2QIErzB8q9uig z#Kl(cKl>kk#}fg7e8Bm<@e1-zqcp~{Gr)4oexQ?8U`2z|Uq0T5L~FpICzxk!&0!)^E0ZecPs@BGcH|VRv2{`_D#&f zXCzay8{!OO{{ZOc3zi3hLkzQR!2W!qLmsMae5o)60-{#}u z8x7s9kNY_(vD@eRmcp&vW0Ja9^@U@@5bjx>@foCB<{}p+a|jR2XF{bKHxo`wpq993 zIO1hB?mmWZP-O~bNZKzhSgFsXgFNFuc@U6MahNVXrrWiVFHkLNoY4%z(crM!2P!oW zyrnpezy_r!g;?d{QqW;=fL1|^lu$31Ov#vJ&gY-R+&t6ALjz#g$f2wLQYnH1#6hxWZ`e80adaBvud|vKX;h6ygrpq6OA&eMj1Z?mg5VLc&u@L#%Tf z%{v?NVVP=Vaqc}&F=or)Np{e7Sm5}TSiHx6QKdg64)Otw0sI zFxC9~$IKQRQicxi*W65twmVp>_s>xR6^+r`QCV~YA%FHcWWS^&h_gw|F?q5sRbGN# zoTBL_b)dSb*GxK(nnvTr0Af`IyuwDJtkg;&Eb#>ZFH_}Y93$=ua#eG2g>A=1d51M? zc^;#Lc_!y}an4+_is)acZyYqlJbc{<<0y6-x<9-?R1j8Otb8~y#-gyBNy{k1;FXOP zT8g5T!~|n3sMD@2Cv5x{E8&>IH@(zfg>+pD5a3 zyO+nxuei!jRKXza+wlakRxTk#28mEevPe>hg_xAoGBI2r*Th(>DyVXdQCBm4Wa*fZ z+#A(Kgn*1>`GePa8v2N+iip`04rGR4CX5`)Igtey%(?4q+t0);dDYQ3Xb9gA5dsRP z#-k(X09dQd<~a#Ut?>R}3nG_oti}`tCny@`d^we*21eq5lcnBDJiw0J>@<1iQm6JR z*LN*I2`p9B$5)a%Lq%{IPE@zPG2&7x)*3DD7P%rZZC-k)Yq>dbUNG?iK}km|S^J3L W2(WWYb-WVUrF}!L-] 24.72K --.-KB/s in 0.009s + +2026-01-06 16:54:07 (2.82 MB/s) - ‘OIP-C.JzdgdK0950bF1_jTmKY46wHaEo?rs=1’ saved [25309/25309] +