From 1565e83efe7ab4c57cb8db7db25a917d29cf661f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=91=AB?= Date: Mon, 13 Apr 2026 17:39:05 +0800 Subject: [PATCH] fix: complete shared generator cleanup coverage Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper. Made-with: Cursor --- src/voxcpm/core.py | 12 ++++++++++-- src/voxcpm/model/utils.py | 10 ++++++++++ src/voxcpm/model/voxcpm.py | 22 ++++++---------------- src/voxcpm/model/voxcpm2.py | 22 ++++++---------------- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index 9fd7015..81692c9 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -8,6 +8,7 @@ from typing import Generator, Optional from huggingface_hub import snapshot_download from .model.voxcpm import VoxCPMModel, LoRAConfig from .model.voxcpm2 import VoxCPM2Model +from .model.utils import next_and_close class VoxCPM: @@ -171,7 +172,7 @@ class VoxCPM: ) def generate(self, *args, **kwargs) -> np.ndarray: - return next(self._generate(*args, streaming=False, **kwargs)) + return next_and_close(self._generate(*args, streaming=False, **kwargs)) def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]: return self._generate(*args, streaming=True, **kwargs) @@ -292,7 +293,14 @@ class VoxCPM: streaming=streaming, ) - for wav, _, _ in generate_result: + if streaming: + try: + for wav, _, _ in generate_result: + yield wav.squeeze(0).cpu().numpy() + finally: + generate_result.close() + else: + wav, _, _ = next_and_close(generate_result) yield wav.squeeze(0).cpu().numpy() finally: diff --git a/src/voxcpm/model/utils.py b/src/voxcpm/model/utils.py index 904c54d..2ca6068 100644 --- a/src/voxcpm/model/utils.py +++ b/src/voxcpm/model/utils.py @@ -3,6 +3,16 @@ import torch from transformers import PreTrainedTokenizer +# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732 +# Explicitly close partially-consumed generators so inference_mode cleanup +# does not get deferred to Python's GC/finalizer path. +def next_and_close(gen): + try: + return next(gen) + finally: + gen.close() + + def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer): """Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters. diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index ed3fe18..e1c50e9 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -44,17 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device - - -# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321 -# Explicitly close partially-consumed generators so inference_mode cleanup -# does not get deferred to Python's GC/finalizer path. -def _next_and_close(gen): - try: - return next(gen) - finally: - gen.close() +from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device class VoxCPMEncoderConfig(BaseModel): @@ -345,7 +335,7 @@ class VoxCPMModel(nn.Module): return get_dtype(self.config.dtype) def generate(self, *args, **kwargs) -> torch.Tensor: - return _next_and_close(self._generate(*args, streaming=False, **kwargs)) + return next_and_close(self._generate(*args, streaming=False, **kwargs)) def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: return self._generate(*args, streaming=True, **kwargs) @@ -471,7 +461,7 @@ class VoxCPMModel(nn.Module): yield decode_audio break else: - latent_pred, pred_audio_feat = _next_and_close(inference_result) + latent_pred, pred_audio_feat = next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -579,7 +569,7 @@ class VoxCPMModel(nn.Module): return merged_cache def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) + return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) def generate_with_prompt_cache_streaming( self, *args, **kwargs @@ -698,7 +688,7 @@ class VoxCPMModel(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) break else: - latent_pred, pred_audio_feat = _next_and_close(inference_result) + latent_pred, pred_audio_feat = next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -721,7 +711,7 @@ class VoxCPMModel(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - return _next_and_close(self._inference(*args, streaming=False, **kwargs)) + return next_and_close(self._inference(*args, streaming=False, **kwargs)) def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: return self._inference(*args, streaming=True, **kwargs) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index f523872..b1bc6a3 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -45,17 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device - - -# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321 -# Explicitly close partially-consumed generators so inference_mode cleanup -# does not get deferred to Python's GC/finalizer path. -def _next_and_close(gen): - try: - return next(gen) - finally: - gen.close() +from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device # A simple function to trim audio silence using VAD, not used default @@ -451,7 +441,7 @@ class VoxCPM2Model(nn.Module): return tokens, feats, t_mask, a_mask def generate(self, *args, **kwargs) -> torch.Tensor: - return _next_and_close(self._generate(*args, streaming=False, **kwargs)) + return next_and_close(self._generate(*args, streaming=False, **kwargs)) def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: return self._generate(*args, streaming=True, **kwargs) @@ -651,7 +641,7 @@ class VoxCPM2Model(nn.Module): yield decode_audio break else: - latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) + latent_pred, pred_audio_feat, context_len = next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -769,7 +759,7 @@ class VoxCPM2Model(nn.Module): return merged def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) + return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) def generate_with_prompt_cache_streaming( self, *args, **kwargs @@ -938,7 +928,7 @@ class VoxCPM2Model(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) break else: - latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) + latent_pred, pred_audio_feat, context_len = next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -961,7 +951,7 @@ class VoxCPM2Model(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - feat_pred, generated_feat, _ = _next_and_close(self._inference(*args, streaming=False, **kwargs)) + feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs)) return feat_pred, generated_feat def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: