fix: complete shared generator cleanup coverage

Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-13 17:39:05 +08:00
parent 61b36d4e56
commit 1565e83efe
4 changed files with 32 additions and 34 deletions
+6 -16
View File
@@ -44,17 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def _next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
class VoxCPMEncoderConfig(BaseModel):
@@ -345,7 +335,7 @@ class VoxCPMModel(nn.Module):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return _next_and_close(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -471,7 +461,7 @@ class VoxCPMModel(nn.Module):
yield decode_audio
break
else:
latent_pred, pred_audio_feat = _next_and_close(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -579,7 +569,7 @@ class VoxCPMModel(nn.Module):
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
@@ -698,7 +688,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = _next_and_close(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -721,7 +711,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return _next_and_close(self._inference(*args, streaming=False, **kwargs))
return next_and_close(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)