From 9dc54eebc29b3896bc96ae428a005ce66084a2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:20:41 +0100 Subject: [PATCH 01/17] Add support for Wan2.1 TAEHV decoding --- stable-diffusion.cpp | 59 +++++--- tae.hpp | 319 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 349 insertions(+), 29 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 2cb588213..05d9c59bb 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -400,8 +400,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -461,10 +461,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -564,14 +564,27 @@ class StableDiffusionGGML { } if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { - first_stage_model = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "first_stage_model", - vae_decode_only, - version); - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); + if (!use_tiny_autoencoder) { + first_stage_model = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + version); + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "first_stage_model"); + } else { + tae_first_stage = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder", + vae_decode_only, + version); + if (sd_ctx_params->vae_conv_direct) { + LOG_INFO("Using Conv2d direct in the tae model"); + tae_first_stage->set_conv2d_direct_enabled(true); + } + } } else if (version == VERSION_CHROMA_RADIANCE) { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu); @@ -598,14 +611,13 @@ class StableDiffusionGGML { } first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); - } - if (use_tiny_autoencoder) { - tae_first_stage = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "decoder.layers", - vae_decode_only, - version); + } else if (use_tiny_autoencoder) { + tae_first_stage = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder.layers", + vae_decode_only, + version); if (sd_ctx_params->vae_conv_direct) { LOG_INFO("Using Conv2d direct in the tae model"); tae_first_stage->set_conv2d_direct_enabled(true); @@ -726,13 +738,16 @@ class StableDiffusionGGML { unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); } size_t vae_params_mem_size = 0; + LOG_DEBUG("Here"); if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { vae_params_mem_size = first_stage_model->get_params_buffer_size(); } if (use_tiny_autoencoder) { + LOG_DEBUG("Here"); if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { return false; } + LOG_DEBUG("Here"); vae_params_mem_size = tae_first_stage->get_params_buffer_size(); } size_t control_net_params_mem_size = 0; diff --git a/tae.hpp b/tae.hpp index 7f3ca449a..1323557c7 100644 --- a/tae.hpp +++ b/tae.hpp @@ -162,6 +162,230 @@ class TinyDecoder : public UnaryBlock { } }; +class TPool : public UnaryBlock { + int stride; + +public: + TPool(int channels, int stride) : stride(stride) { + blocks["conv"] = std::shared_ptr(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto h = x; + if (stride != 1) { + h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride); + } + h = conv->forward(ctx, h); + return h; + } +}; + +class TGrow : public UnaryBlock { + int stride; + +public: + TGrow(int channels, int stride) : stride(stride) { + blocks["conv"] = std::shared_ptr(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto h = conv->forward(ctx, x); + if (stride != 1) { + h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride); + } + return h; + } +}; + +class MemBlock : public GGMLBlock { + bool has_skip_conv = false; + +public: + MemBlock(int channels, int out_channels) : has_skip_conv(channels != out_channels) { + blocks["conv.0"] = std::shared_ptr(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["conv.2"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["conv.4"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + if (has_skip_conv) { + blocks["skip"] = std::shared_ptr(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) { + // x: [n, channels, h, w] + auto conv0 = std::dynamic_pointer_cast(blocks["conv.0"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv.2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv.4"]); + + auto h = ggml_concat(ctx->ggml_ctx, x, past, 2); + h = conv0->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + + auto skip = x; + if (has_skip_conv) { + auto skip_conv = std::dynamic_pointer_cast(blocks["skip"]); + skip = skip_conv->forward(ctx, x); + } + h = ggml_add_inplace(ctx->ggml_ctx, h, skip); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + return h; + } +}; + +class Clamp : public UnaryBlock { +public: + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + return ggml_scale_inplace(ctx->ggml_ctx, + ggml_tanh_inplace(ctx->ggml_ctx, + ggml_scale(ctx->ggml_ctx, x, 1.0f / 3.0f)), + 3.0f); + } +}; + +class TinyVideoEncoder : public UnaryBlock { + int in_channels = 3; + int channels = 64; + int z_channels = 4; + int num_blocks = 3; + int num_layers = 3; + int patch_size = 1; + +public: + TinyVideoEncoder(int z_channels = 4, int patch_size = 1) + : z_channels(z_channels), patch_size(patch_size) { + int index = 0; + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels * patch_size * patch_size, channels, {3, 3}, {1, 1}, {1, 1})); + index++; // nn.ReLU() + for (int i = 0; i < num_layers; i++) { + int stride = i == num_layers - 1 ? 1 : 2; + blocks[std::to_string(index++)] = std::shared_ptr(new TPool(channels, stride)); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); + for (int j = 0; j < num_blocks; j++) { + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels, channels)); + } + } + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { + // return z; + auto first_conv = std::dynamic_pointer_cast(blocks["0"]); + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(num_layers * (num_blocks + 2) + 1)]); + auto h = first_conv->forward(ctx, z); + + h = ggml_relu_inplace(ctx->ggml_ctx, h); + + for (int i = 2; i < num_layers * (num_blocks + 2) + 2; i++) { + if (blocks.find(std::to_string(i)) == blocks.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks[std::to_string(i)]); + h = block->forward(ctx, h); + } + h = last_conv->forward(ctx, h); + return h; + } +}; + +class TinyVideoDecoder : public UnaryBlock { + int z_channels = 4; + int out_channels = 3; + int num_blocks = 3; + static const int num_layers = 3; + int channels[num_layers + 1] = {256, 128, 64, 64}; + +public: + TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) { + int index = 0; + // n_f = [256, 128, 64, 64] + blocks[std::to_string(index++)] = std::shared_ptr(new Clamp()); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); + index++; // nn.ReLU() + for (int i = 0; i < num_layers; i++) { + int stride = i == 0 ? 1 : 2; + for (int j = 0; j < num_blocks; j++) { + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels[i], channels[i])); + } + index++; // nn.Upsample() + blocks[std::to_string(index++)] = std::shared_ptr(new TGrow(channels[i], stride)); + LOG_DEBUG("Create Conv2d %d shape = %d %d", index, channels[i], channels[i + 1]); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); + } + index++; // nn.ReLU() + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1})); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { + LOG_DEBUG("Here"); + auto clamp = std::dynamic_pointer_cast(blocks["0"]); + auto first_conv = std::dynamic_pointer_cast(blocks["1"]); + auto h = first_conv->forward(ctx, clamp->forward(ctx, z)); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + int index = 3; + for (int i = 0; i < num_layers; i++) { + for (int j = 0; j < num_blocks; j++) { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto mem = ggml_pad(ctx->ggml_ctx, h, 0, 0, 0, 1); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + h = block->forward(ctx, h, mem); + } + // upsample + index++; + h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST); + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h); + block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + h = block->forward(ctx, h); + } + h = ggml_relu_inplace(ctx->ggml_ctx, h); + + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(++index)]); + h = last_conv->forward(ctx, h); + + // shape(W, H, 3, T+3) => shape(W, H, 3, T) + h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0); + LOG_DEBUG("Here"); + return h; + } +}; + +class TAEHV : public GGMLBlock { +protected: + bool decode_only; + +public: + TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) + : decode_only(decode_only) { + int z_channels = 16; + int patch = 1; + if (version == VERSION_WAN2_2_TI2V) { + z_channels = 48; + patch = 2; + } + blocks["decoder"] = std::shared_ptr(new TinyVideoDecoder(z_channels, patch)); + if (!decode_only) { + blocks["encoder"] = std::shared_ptr(new TinyVideoEncoder(z_channels, patch)); + } + } + + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { + LOG_DEBUG("Decode"); + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + return decoder->forward(ctx, z); + } + + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + return nullptr; + // auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); + // return encoder->forward(ctx, x); + } +}; + class TAESD : public GGMLBlock { protected: bool decode_only; @@ -192,18 +416,30 @@ class TAESD : public GGMLBlock { }; struct TinyAutoEncoder : public GGMLRunner { + TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu) + : GGMLRunner(backend, offload_params_to_cpu) {} + virtual void compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) = 0; + + virtual bool load_from_file(const std::string& file_path, int n_threads) = 0; +}; + +struct TinyImageAutoEncoder : public TinyAutoEncoder { TAESD taesd; bool decode_only = false; - TinyAutoEncoder(ggml_backend_t backend, - bool offload_params_to_cpu, - const String2TensorStorage& tensor_storage_map, - const std::string prefix, - bool decoder_only = true, - SDVersion version = VERSION_SD1) + TinyImageAutoEncoder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_SD1) : decode_only(decoder_only), taesd(decoder_only, version), - GGMLRunner(backend, offload_params_to_cpu) { + TinyAutoEncoder(backend, offload_params_to_cpu) { taesd.init(params_ctx, tensor_storage_map, prefix); } @@ -260,4 +496,73 @@ struct TinyAutoEncoder : public GGMLRunner { } }; +struct TinyVideoAutoEncoder : public TinyAutoEncoder { + TAEHV taehv; + bool decode_only = false; + + TinyVideoAutoEncoder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_WAN2) + : decode_only(decoder_only), + taehv(decoder_only, version), + TinyAutoEncoder(backend, offload_params_to_cpu) { + taehv.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "taehv"; + } + + bool load_from_file(const std::string& file_path, int n_threads) { + LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false"); + alloc_params_buffer(); + std::map taehv_tensors; + taehv.get_param_tensors(taehv_tensors); + std::set ignore_tensors; + if (decode_only) { + ignore_tensors.insert("encoder."); + } + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads); + + if (!success) { + LOG_ERROR("load tae tensors from model loader failed"); + return false; + } + + LOG_INFO("taehv model loaded"); + return success; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { + struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + z = to_backend(z); + auto runner_ctx = get_context(); + struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z); + ggml_build_forward_expand(gf, out); + return gf; + } + + void compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(z, decode_graph); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } +}; + #endif // __TAE_HPP__ \ No newline at end of file From 85607ea23a7f030d1c52f4b4e616a7d473d3e7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:20:46 +0100 Subject: [PATCH 02/17] --tae instead of --taesd --- examples/common/common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/common/common.hpp b/examples/common/common.hpp index bf38379d2..38f1b07fe 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -402,7 +402,7 @@ struct SDContextParams { "--vae", "path to standalone vae model", &vae_path}, - {"", + {"--tae", "--taesd", "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", &taesd_path}, From d7fc01227d9beddef41246995f1edacdfbc1afd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:20:52 +0100 Subject: [PATCH 03/17] progress towards video support --- stable-diffusion.cpp | 5 +++++ tae.hpp | 51 ++++++++++++++++++++++---------------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 05d9c59bb..0072eade9 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2312,6 +2312,10 @@ class StableDiffusionGGML { first_stage_model->free_compute_buffer(); process_vae_output_tensor(result); } else { + if (sd_version_is_wan(version)) { + x = ggml_permute(work_ctx, x, 0, 1, 3, 2); + } + if (vae_tiling_params.enabled && !decode_video) { // split latent in 64x64 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { @@ -2322,6 +2326,7 @@ class StableDiffusionGGML { tae_first_stage->compute(n_threads, x, true, &result); } tae_first_stage->free_compute_buffer(); + } int64_t t1 = ggml_time_ms(); diff --git a/tae.hpp b/tae.hpp index 1323557c7..68d6347f5 100644 --- a/tae.hpp +++ b/tae.hpp @@ -237,19 +237,9 @@ class MemBlock : public GGMLBlock { } }; -class Clamp : public UnaryBlock { -public: - struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { - return ggml_scale_inplace(ctx->ggml_ctx, - ggml_tanh_inplace(ctx->ggml_ctx, - ggml_scale(ctx->ggml_ctx, x, 1.0f / 3.0f)), - 3.0f); - } -}; - class TinyVideoEncoder : public UnaryBlock { int in_channels = 3; - int channels = 64; + int hidden = 64; int z_channels = 4; int num_blocks = 3; int num_layers = 3; @@ -259,17 +249,17 @@ class TinyVideoEncoder : public UnaryBlock { TinyVideoEncoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels), patch_size(patch_size) { int index = 0; - blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels * patch_size * patch_size, channels, {3, 3}, {1, 1}, {1, 1})); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() for (int i = 0; i < num_layers; i++) { int stride = i == num_layers - 1 ? 1 : 2; - blocks[std::to_string(index++)] = std::shared_ptr(new TPool(channels, stride)); - blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); + blocks[std::to_string(index++)] = std::shared_ptr(new TPool(hidden, stride)); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int j = 0; j < num_blocks; j++) { - blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(channels, channels)); + blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(hidden, hidden)); } } - blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1})); } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { @@ -301,9 +291,7 @@ class TinyVideoDecoder : public UnaryBlock { public: TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) { - int index = 0; - // n_f = [256, 128, 64, 64] - blocks[std::to_string(index++)] = std::shared_ptr(new Clamp()); + int index = 1; // Clamp() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() for (int i = 0; i < num_layers; i++) { @@ -322,11 +310,17 @@ class TinyVideoDecoder : public UnaryBlock { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { LOG_DEBUG("Here"); - auto clamp = std::dynamic_pointer_cast(blocks["0"]); auto first_conv = std::dynamic_pointer_cast(blocks["1"]); - auto h = first_conv->forward(ctx, clamp->forward(ctx, z)); - h = ggml_relu_inplace(ctx->ggml_ctx, h); - int index = 3; + + // Clamp() + auto h = ggml_scale_inplace(ctx->ggml_ctx, + ggml_tanh_inplace(ctx->ggml_ctx, + ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), + 3.0f); + + h = first_conv->forward(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); + int index = 3; for (int i = 0; i < num_layers; i++) { for (int j = 0; j < num_blocks; j++) { auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); @@ -350,6 +344,7 @@ class TinyVideoDecoder : public UnaryBlock { // shape(W, H, 3, T+3) => shape(W, H, 3, T) h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0); LOG_DEBUG("Here"); + print_ggml_tensor(h, true); return h; } }; @@ -357,10 +352,11 @@ class TinyVideoDecoder : public UnaryBlock { class TAEHV : public GGMLBlock { protected: bool decode_only; + SDVersion version; public: TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) - : decode_only(decode_only) { + : decode_only(decode_only), version(version) { int z_channels = 16; int patch = 1; if (version == VERSION_WAN2_2_TI2V) { @@ -376,7 +372,12 @@ class TAEHV : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { LOG_DEBUG("Decode"); auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); - return decoder->forward(ctx, z); + auto result = decoder->forward(ctx, z); + LOG_DEBUG("Decoded"); + if (sd_version_is_wan(version)) { + result = ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2); + } + return result; } struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { From d6920ccf4c36b77b73d349c1cc8f2290594d75d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:20:56 +0100 Subject: [PATCH 04/17] Wan2.1 decode not crashing anymore (still broken) --- stable-diffusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 0072eade9..9db202760 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -3871,7 +3871,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true); int64_t t5 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000); - if (sd_ctx->sd->free_params_immediately) { + if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { sd_ctx->sd->first_stage_model->free_params_buffer(); } From 11470b7d38846123e46966fc29aea3ce80c7bbdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:21:00 +0100 Subject: [PATCH 05/17] Less broken video decode + remove log spam --- tae.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tae.hpp b/tae.hpp index 68d6347f5..28857a90c 100644 --- a/tae.hpp +++ b/tae.hpp @@ -309,7 +309,6 @@ class TinyVideoDecoder : public UnaryBlock { } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { - LOG_DEBUG("Here"); auto first_conv = std::dynamic_pointer_cast(blocks["1"]); // Clamp() @@ -343,8 +342,6 @@ class TinyVideoDecoder : public UnaryBlock { // shape(W, H, 3, T+3) => shape(W, H, 3, T) h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0); - LOG_DEBUG("Here"); - print_ggml_tensor(h, true); return h; } }; @@ -370,12 +367,11 @@ class TAEHV : public GGMLBlock { } struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { - LOG_DEBUG("Decode"); auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto result = decoder->forward(ctx, z); - LOG_DEBUG("Decoded"); if (sd_version_is_wan(version)) { - result = ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2); + // (W, H, C, T) -> (W, H, T, C) + result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2)); } return result; } From 9842e34adc71bc789c734a56a9febcd16d8d530c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:22:02 +0100 Subject: [PATCH 06/17] Taehv fixes Co-authored-by: Ollin Boer Bohan --- tae.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tae.hpp b/tae.hpp index 28857a90c..2c49b23f4 100644 --- a/tae.hpp +++ b/tae.hpp @@ -224,7 +224,6 @@ class MemBlock : public GGMLBlock { h = conv1->forward(ctx, h); h = ggml_relu_inplace(ctx->ggml_ctx, h); h = conv2->forward(ctx, h); - h = ggml_relu_inplace(ctx->ggml_ctx, h); auto skip = x; if (has_skip_conv) { @@ -323,7 +322,7 @@ class TinyVideoDecoder : public UnaryBlock { for (int i = 0; i < num_layers; i++) { for (int j = 0; j < num_blocks; j++) { auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); - auto mem = ggml_pad(ctx->ggml_ctx, h, 0, 0, 0, 1); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1,0); mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); h = block->forward(ctx, h, mem); } @@ -341,7 +340,7 @@ class TinyVideoDecoder : public UnaryBlock { h = last_conv->forward(ctx, h); // shape(W, H, 3, T+3) => shape(W, H, 3, T) - h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0); + h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]); return h; } }; From 2fc458b0d2b0ae6e42f509795a7f9aaad62a963b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 11 Dec 2025 19:22:06 +0100 Subject: [PATCH 07/17] Adapt to lastest changes --- tae.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tae.hpp b/tae.hpp index 2c49b23f4..e4183996e 100644 --- a/tae.hpp +++ b/tae.hpp @@ -414,7 +414,7 @@ class TAESD : public GGMLBlock { struct TinyAutoEncoder : public GGMLRunner { TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu) : GGMLRunner(backend, offload_params_to_cpu) {} - virtual void compute(const int n_threads, + virtual bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -548,7 +548,7 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -557,7 +557,7 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder { return build_graph(z, decode_graph); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } }; From 2162441f6c907539e6ac644f2daebc79df7a669e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 12 Dec 2025 01:29:05 +0100 Subject: [PATCH 08/17] taew2.1 encode support --- tae.hpp | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/tae.hpp b/tae.hpp index e4183996e..c88dd7ce3 100644 --- a/tae.hpp +++ b/tae.hpp @@ -258,24 +258,29 @@ class TinyVideoEncoder : public UnaryBlock { blocks[std::to_string(index++)] = std::shared_ptr(new MemBlock(hidden, hidden)); } } - blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1})); + blocks[std::to_string(index)] = std::shared_ptr(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1})); } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { - // return z; auto first_conv = std::dynamic_pointer_cast(blocks["0"]); - auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(num_layers * (num_blocks + 2) + 1)]); auto h = first_conv->forward(ctx, z); - h = ggml_relu_inplace(ctx->ggml_ctx, h); - - for (int i = 2; i < num_layers * (num_blocks + 2) + 2; i++) { - if (blocks.find(std::to_string(i)) == blocks.end()) { - continue; + + int index = 2; + for (int i = 0; i < num_layers; i++) { + auto pool = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto conv = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + + h = pool->forward(ctx, h); + h = conv->forward(ctx, h); + for (int j = 0; j < num_blocks; j++) { + auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + h = block->forward(ctx, h, mem); } - auto block = std::dynamic_pointer_cast(blocks[std::to_string(i)]); - h = block->forward(ctx, h); } + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(index)]); h = last_conv->forward(ctx, h); return h; } @@ -322,7 +327,7 @@ class TinyVideoDecoder : public UnaryBlock { for (int i = 0; i < num_layers; i++) { for (int j = 0; j < num_blocks; j++) { auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); - auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1,0); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); h = block->forward(ctx, h, mem); } @@ -339,7 +344,7 @@ class TinyVideoDecoder : public UnaryBlock { auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(++index)]); h = last_conv->forward(ctx, h); - // shape(W, H, 3, T+3) => shape(W, H, 3, T) + // shape(W, H, 3, 3 + T) => shape(W, H, 3, T) h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]); return h; } @@ -376,9 +381,20 @@ class TAEHV : public GGMLBlock { } struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { - return nullptr; - // auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); - // return encoder->forward(ctx, x); + auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); + // (W, H, T, C) -> (W, H, C, T) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + int64_t num_frames = x->ne[3]; + if (num_frames % 4) { + // pad to multiple of 4 at the end + auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]); + for (int i = 0; i < 4 - num_frames % 4; i++) { + x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3); + } + } + x = encoder->forward(ctx, x); + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + return x; } }; From 9abc5132c0ea968c74c5263ca9d6b5a05d3342ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 13 Dec 2025 21:55:25 +0100 Subject: [PATCH 09/17] fix permute ctx for videeo decoding Co-authored-by: Ollin Boer Bohan --- stable-diffusion.cpp | 5 ----- tae.hpp | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9db202760..0a96ff2a8 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2312,10 +2312,6 @@ class StableDiffusionGGML { first_stage_model->free_compute_buffer(); process_vae_output_tensor(result); } else { - if (sd_version_is_wan(version)) { - x = ggml_permute(work_ctx, x, 0, 1, 3, 2); - } - if (vae_tiling_params.enabled && !decode_video) { // split latent in 64x64 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { @@ -2326,7 +2322,6 @@ class StableDiffusionGGML { tae_first_stage->compute(n_threads, x, true, &result); } tae_first_stage->free_compute_buffer(); - } int64_t t1 = ggml_time_ms(); diff --git a/tae.hpp b/tae.hpp index c88dd7ce3..d08f8086d 100644 --- a/tae.hpp +++ b/tae.hpp @@ -372,6 +372,10 @@ class TAEHV : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + if (sd_version_is_wan(version)) { + // (W, H, C, T) -> (W, H, T, C) + z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2)); + } auto result = decoder->forward(ctx, z); if (sd_version_is_wan(version)) { // (W, H, C, T) -> (W, H, T, C) From 068d928ad54d1ffcc1e5898ebf1f81b64f69be52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 13 Dec 2025 23:11:06 +0100 Subject: [PATCH 10/17] taehv: support patchified latents --- tae.hpp | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/tae.hpp b/tae.hpp index d08f8086d..df6d4a6c7 100644 --- a/tae.hpp +++ b/tae.hpp @@ -264,13 +264,13 @@ class TinyVideoEncoder : public UnaryBlock { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { auto first_conv = std::dynamic_pointer_cast(blocks["0"]); auto h = first_conv->forward(ctx, z); - h = ggml_relu_inplace(ctx->ggml_ctx, h); - + h = ggml_relu_inplace(ctx->ggml_ctx, h); + int index = 2; for (int i = 0; i < num_layers; i++) { auto pool = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); auto conv = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); - + h = pool->forward(ctx, h); h = conv->forward(ctx, h); for (int j = 0; j < num_blocks; j++) { @@ -280,21 +280,77 @@ class TinyVideoEncoder : public UnaryBlock { h = block->forward(ctx, h, mem); } } - auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(index)]); - h = last_conv->forward(ctx, h); + auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(index)]); + h = last_conv->forward(ctx, h); return h; } }; + +struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c, h*q, w*r] + // return: [f, b*c*r*q, h, w] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + + int64_t w_in = x->ne[0]; + int64_t h_in = x->ne[1]; + int64_t cb = x->ne[2]; // b*c + int64_t f = x->ne[3]; + + int64_t w = w_in / r; + int64_t h = h_in / q; + + x = ggml_reshape_4d(ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, r, h*q, w] + x = ggml_reshape_4d(ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c*r, q, h, w] + x = ggml_reshape_4d(ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w] + return x; +} + +struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c*r*q, h, w] + // return: [f, b*c, h*q, w*r] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + int64_t c = x->ne[2] / b / q / r; + int64_t f = x->ne[3]; + int64_t h = x->ne[1]; + int64_t w = x->ne[0]; + + + x = ggml_reshape_4d(ctx, x, w * h, q * r, c * b, f); // [f, b*c, r*q, h*w] + x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q*h, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [f*b*c, q*h, w, r] + x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [f*b*c, q, h, w*r] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, h, q, w*r] + x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r] + return x; +} + class TinyVideoDecoder : public UnaryBlock { int z_channels = 4; int out_channels = 3; int num_blocks = 3; static const int num_layers = 3; int channels[num_layers + 1] = {256, 128, 64, 64}; + int patch_size = 1; public: - TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) { + TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels), patch_size(patch_size) { int index = 1; // Clamp() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() @@ -343,7 +399,9 @@ class TinyVideoDecoder : public UnaryBlock { auto last_conv = std::dynamic_pointer_cast(blocks[std::to_string(++index)]); h = last_conv->forward(ctx, h); - + if (patch_size > 1) { + h = unpatchify(ctx->ggml_ctx, h, patch_size, 1); + } // shape(W, H, 3, 3 + T) => shape(W, H, 3, T) h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]); return h; @@ -376,7 +434,7 @@ class TAEHV : public GGMLBlock { // (W, H, C, T) -> (W, H, T, C) z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2)); } - auto result = decoder->forward(ctx, z); + auto result = decoder->forward(ctx, z); if (sd_version_is_wan(version)) { // (W, H, C, T) -> (W, H, T, C) result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2)); @@ -387,7 +445,7 @@ class TAEHV : public GGMLBlock { struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); // (W, H, T, C) -> (W, H, C, T) - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); int64_t num_frames = x->ne[3]; if (num_frames % 4) { // pad to multiple of 4 at the end From e4cbcdc801a67db57a138f386725efa9a7f8fc51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 13 Dec 2025 23:31:49 +0100 Subject: [PATCH 11/17] fix patched pixels order --- tae.hpp | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tae.hpp b/tae.hpp index df6d4a6c7..f092adb46 100644 --- a/tae.hpp +++ b/tae.hpp @@ -299,19 +299,20 @@ struct ggml_tensor* patchify(struct ggml_context* ctx, int64_t r = patch_size; int64_t q = patch_size; - int64_t w_in = x->ne[0]; - int64_t h_in = x->ne[1]; - int64_t cb = x->ne[2]; // b*c - int64_t f = x->ne[3]; + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t C = x->ne[2]; + int64_t f = x->ne[3]; + + int64_t w = W / r; + int64_t h = H / q; + + x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f] + x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f] + x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w] - int64_t w = w_in / r; - int64_t h = h_in / q; - - x = ggml_reshape_4d(ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, r, h*q, w] - x = ggml_reshape_4d(ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c*r, q, h, w] - x = ggml_reshape_4d(ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w] return x; } @@ -332,12 +333,12 @@ struct ggml_tensor* unpatchify(struct ggml_context* ctx, int64_t w = x->ne[0]; - x = ggml_reshape_4d(ctx, x, w * h, q * r, c * b, f); // [f, b*c, r*q, h*w] - x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q*h, w] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [f*b*c, q*h, w, r] - x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [f*b*c, q, h, w*r] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, h, q, w*r] - x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r] + x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); + return x; } From a7a791d4c0a814aade2344009382fd9a1dd8377e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 13 Dec 2025 23:49:58 +0100 Subject: [PATCH 12/17] taehv: patchify encode --- tae.hpp | 119 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/tae.hpp b/tae.hpp index f092adb46..056c7384e 100644 --- a/tae.hpp +++ b/tae.hpp @@ -236,6 +236,60 @@ class MemBlock : public GGMLBlock { } }; +struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c, h*q, w*r] + // return: [f, b*c*r*q, h, w] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t C = x->ne[2]; + int64_t f = x->ne[3]; + + int64_t w = W / r; + int64_t h = H / q; + + x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f] + x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f] + x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w] + + return x; +} + +struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { + // x: [f, b*c*r*q, h, w] + // return: [f, b*c, h*q, w*r] + if (patch_size == 1) { + return x; + } + int64_t r = patch_size; + int64_t q = patch_size; + int64_t c = x->ne[2] / b / q / r; + int64_t f = x->ne[3]; + int64_t h = x->ne[1]; + int64_t w = x->ne[0]; + + x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f] + x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); + + return x; +} + class TinyVideoEncoder : public UnaryBlock { int in_channels = 3; int hidden = 64; @@ -263,8 +317,13 @@ class TinyVideoEncoder : public UnaryBlock { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { auto first_conv = std::dynamic_pointer_cast(blocks["0"]); - auto h = first_conv->forward(ctx, z); - h = ggml_relu_inplace(ctx->ggml_ctx, h); + + if (patch_size > 1) { + z = patchify(ctx->ggml_ctx, z, patch_size, 1); + } + + auto h = first_conv->forward(ctx, z); + h = ggml_relu_inplace(ctx->ggml_ctx, h); int index = 2; for (int i = 0; i < num_layers; i++) { @@ -286,62 +345,6 @@ class TinyVideoEncoder : public UnaryBlock { } }; - -struct ggml_tensor* patchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t patch_size, - int64_t b = 1) { - // x: [f, b*c, h*q, w*r] - // return: [f, b*c*r*q, h, w] - if (patch_size == 1) { - return x; - } - int64_t r = patch_size; - int64_t q = patch_size; - - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - int64_t C = x->ne[2]; - int64_t f = x->ne[3]; - - int64_t w = W / r; - int64_t h = H / q; - - x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f] - x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f] - x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w] - - return x; -} - -struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t patch_size, - int64_t b = 1) { - // x: [f, b*c*r*q, h, w] - // return: [f, b*c, h*q, w*r] - if (patch_size == 1) { - return x; - } - int64_t r = patch_size; - int64_t q = patch_size; - int64_t c = x->ne[2] / b / q / r; - int64_t f = x->ne[3]; - int64_t h = x->ne[1]; - int64_t w = x->ne[0]; - - - x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f] - x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f] - x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); - - return x; -} - class TinyVideoDecoder : public UnaryBlock { int z_channels = 4; int out_channels = 3; From 6a653e6f12a648c6465d902021d884028d5e5a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 15 Dec 2025 03:14:29 +0100 Subject: [PATCH 13/17] fix img2vid --- stable-diffusion.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 0a96ff2a8..a3c885f7f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -3633,7 +3633,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); ggml_set_f32(denoise_mask, 1.f); - sd_ctx->sd->process_latent_out(init_latent); + if (!sd_ctx->sd->use_tiny_autoencoder) + sd_ctx->sd->process_latent_out(init_latent); ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3); @@ -3643,7 +3644,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } }); - sd_ctx->sd->process_latent_in(init_latent); + if (!sd_ctx->sd->use_tiny_autoencoder) + sd_ctx->sd->process_latent_in(init_latent); int64_t t2 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); From 8a15925496c2e89c3911ebb4eb96a4fe9b6d0350 Mon Sep 17 00:00:00 2001 From: leejet Date: Tue, 16 Dec 2025 22:47:36 +0800 Subject: [PATCH 14/17] format code --- stable-diffusion.cpp | 12 ++++++------ tae.hpp | 12 ++++++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 2241a7edc..e09137688 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -398,8 +398,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -459,10 +459,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, diff --git a/tae.hpp b/tae.hpp index 056c7384e..5a4ed9f39 100644 --- a/tae.hpp +++ b/tae.hpp @@ -166,7 +166,8 @@ class TPool : public UnaryBlock { int stride; public: - TPool(int channels, int stride) : stride(stride) { + TPool(int channels, int stride) + : stride(stride) { blocks["conv"] = std::shared_ptr(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); } @@ -185,7 +186,8 @@ class TGrow : public UnaryBlock { int stride; public: - TGrow(int channels, int stride) : stride(stride) { + TGrow(int channels, int stride) + : stride(stride) { blocks["conv"] = std::shared_ptr(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); } @@ -203,7 +205,8 @@ class MemBlock : public GGMLBlock { bool has_skip_conv = false; public: - MemBlock(int channels, int out_channels) : has_skip_conv(channels != out_channels) { + MemBlock(int channels, int out_channels) + : has_skip_conv(channels != out_channels) { blocks["conv.0"] = std::shared_ptr(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); @@ -354,7 +357,8 @@ class TinyVideoDecoder : public UnaryBlock { int patch_size = 1; public: - TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels), patch_size(patch_size) { + TinyVideoDecoder(int z_channels = 4, int patch_size = 1) + : z_channels(z_channels), patch_size(patch_size) { int index = 1; // Clamp() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() From a7dd964bf7ecda3154683b2f358e771429aed30e Mon Sep 17 00:00:00 2001 From: leejet Date: Tue, 16 Dec 2025 22:49:08 +0800 Subject: [PATCH 15/17] remove some debug log --- stable-diffusion.cpp | 3 --- tae.hpp | 1 - 2 files changed, 4 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e09137688..44bd3ccac 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -736,16 +736,13 @@ class StableDiffusionGGML { unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); } size_t vae_params_mem_size = 0; - LOG_DEBUG("Here"); if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { vae_params_mem_size = first_stage_model->get_params_buffer_size(); } if (use_tiny_autoencoder) { - LOG_DEBUG("Here"); if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { return false; } - LOG_DEBUG("Here"); vae_params_mem_size = tae_first_stage->get_params_buffer_size(); } size_t control_net_params_mem_size = 0; diff --git a/tae.hpp b/tae.hpp index 5a4ed9f39..5da76e692 100644 --- a/tae.hpp +++ b/tae.hpp @@ -369,7 +369,6 @@ class TinyVideoDecoder : public UnaryBlock { } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new TGrow(channels[i], stride)); - LOG_DEBUG("Create Conv2d %d shape = %d %d", index, channels[i], channels[i + 1]); blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); } index++; // nn.ReLU() From c6e6144986e4083f936f0b94d8e013b981c96c5e Mon Sep 17 00:00:00 2001 From: leejet Date: Tue, 16 Dec 2025 22:52:11 +0800 Subject: [PATCH 16/17] set option correctly --- examples/common/common.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 3fdc5c32d..2d890af8a 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -402,10 +402,14 @@ struct SDContextParams { "--vae", "path to standalone vae model", &vae_path}, - {"--tae", + {"", "--taesd", "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", &taesd_path}, + {"", + "--tae", + "alias of --taesd", + &taesd_path}, {"", "--control-net", "path to control net model", From 15179850607b8377ee1b1bf7c5ee775d9fcc8efe Mon Sep 17 00:00:00 2001 From: leejet Date: Tue, 16 Dec 2025 22:54:26 +0800 Subject: [PATCH 17/17] update docs --- examples/cli/README.md | 1 + examples/server/README.md | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/cli/README.md b/examples/cli/README.md index 02650f703..8531b2aed 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -31,6 +31,7 @@ Context Options: --high-noise-diffusion-model path to the standalone high noise diffusion model --vae path to standalone vae model --taesd path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + --tae alias of --taesd --control-net path to control net model --embd-dir embeddings directory --lora-model-dir lora model directory diff --git a/examples/server/README.md b/examples/server/README.md index 43c5d5f57..a475856fa 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -24,6 +24,7 @@ Context Options: --high-noise-diffusion-model path to the standalone high noise diffusion model --vae path to standalone vae model --taesd path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + --tae alias of --taesd --control-net path to control net model --embd-dir embeddings directory --lora-model-dir lora model directory