fix: complete shared generator cleanup coverage

Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-13 17:39:05 +08:00
parent 61b36d4e56
commit 1565e83efe
4 changed files with 32 additions and 34 deletions
+9 -1
View File
@@ -8,6 +8,7 @@ from typing import Generator, Optional
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model from .model.voxcpm2 import VoxCPM2Model
from .model.utils import next_and_close
class VoxCPM: class VoxCPM:
@@ -171,7 +172,7 @@ class VoxCPM:
) )
def generate(self, *args, **kwargs) -> np.ndarray: def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs)) return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]: def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs) return self._generate(*args, streaming=True, **kwargs)
@@ -292,8 +293,15 @@ class VoxCPM:
streaming=streaming, streaming=streaming,
) )
if streaming:
try:
for wav, _, _ in generate_result: for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy() yield wav.squeeze(0).cpu().numpy()
finally:
generate_result.close()
else:
wav, _, _ = next_and_close(generate_result)
yield wav.squeeze(0).cpu().numpy()
finally: finally:
for tmp_path in temp_files: for tmp_path in temp_files:
+10
View File
@@ -3,6 +3,16 @@ import torch
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer): def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters. """Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
+6 -16
View File
@@ -44,17 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def _next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
class VoxCPMEncoderConfig(BaseModel): class VoxCPMEncoderConfig(BaseModel):
@@ -345,7 +335,7 @@ class VoxCPMModel(nn.Module):
return get_dtype(self.config.dtype) return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor: def generate(self, *args, **kwargs) -> torch.Tensor:
return _next_and_close(self._generate(*args, streaming=False, **kwargs)) return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs) return self._generate(*args, streaming=True, **kwargs)
@@ -471,7 +461,7 @@ class VoxCPMModel(nn.Module):
yield decode_audio yield decode_audio
break break
else: else:
latent_pred, pred_audio_feat = _next_and_close(inference_result) latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -579,7 +569,7 @@ class VoxCPMModel(nn.Module):
return merged_cache return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming( def generate_with_prompt_cache_streaming(
self, *args, **kwargs self, *args, **kwargs
@@ -698,7 +688,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
break break
else: else:
latent_pred, pred_audio_feat = _next_and_close(inference_result) latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -721,7 +711,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return _next_and_close(self._inference(*args, streaming=False, **kwargs)) return next_and_close(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs) return self._inference(*args, streaming=True, **kwargs)
+6 -16
View File
@@ -45,17 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def _next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
# A simple function to trim audio silence using VAD, not used default # A simple function to trim audio silence using VAD, not used default
@@ -451,7 +441,7 @@ class VoxCPM2Model(nn.Module):
return tokens, feats, t_mask, a_mask return tokens, feats, t_mask, a_mask
def generate(self, *args, **kwargs) -> torch.Tensor: def generate(self, *args, **kwargs) -> torch.Tensor:
return _next_and_close(self._generate(*args, streaming=False, **kwargs)) return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs) return self._generate(*args, streaming=True, **kwargs)
@@ -651,7 +641,7 @@ class VoxCPM2Model(nn.Module):
yield decode_audio yield decode_audio
break break
else: else:
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -769,7 +759,7 @@ class VoxCPM2Model(nn.Module):
return merged return merged
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming( def generate_with_prompt_cache_streaming(
self, *args, **kwargs self, *args, **kwargs
@@ -938,7 +928,7 @@ class VoxCPM2Model(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
break break
else: else:
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -961,7 +951,7 @@ class VoxCPM2Model(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
feat_pred, generated_feat, _ = _next_and_close(self._inference(*args, streaming=False, **kwargs)) feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
return feat_pred, generated_feat return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: