diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index 55c96e9..ed3fe18 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -47,6 +47,16 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device +# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321 +# Explicitly close partially-consumed generators so inference_mode cleanup +# does not get deferred to Python's GC/finalizer path. +def _next_and_close(gen): + try: + return next(gen) + finally: + gen.close() + + class VoxCPMEncoderConfig(BaseModel): hidden_dim: int = 1024 ffn_dim: int = 4096 @@ -335,7 +345,7 @@ class VoxCPMModel(nn.Module): return get_dtype(self.config.dtype) def generate(self, *args, **kwargs) -> torch.Tensor: - return next(self._generate(*args, streaming=False, **kwargs)) + return _next_and_close(self._generate(*args, streaming=False, **kwargs)) def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: return self._generate(*args, streaming=True, **kwargs) @@ -461,7 +471,7 @@ class VoxCPMModel(nn.Module): yield decode_audio break else: - latent_pred, pred_audio_feat = next(inference_result) + latent_pred, pred_audio_feat = _next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -569,7 +579,7 @@ class VoxCPMModel(nn.Module): return merged_cache def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) + return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) def generate_with_prompt_cache_streaming( self, *args, **kwargs @@ -688,7 +698,7 @@ class VoxCPMModel(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) break else: - latent_pred, pred_audio_feat = next(inference_result) + latent_pred, pred_audio_feat = _next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -711,7 +721,7 @@ class VoxCPMModel(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - return next(self._inference(*args, streaming=False, **kwargs)) + return _next_and_close(self._inference(*args, streaming=False, **kwargs)) def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: return self._inference(*args, streaming=True, **kwargs) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 3ef5592..f523872 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -48,6 +48,16 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device +# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321 +# Explicitly close partially-consumed generators so inference_mode cleanup +# does not get deferred to Python's GC/finalizer path. +def _next_and_close(gen): + try: + return next(gen) + finally: + gen.close() + + # A simple function to trim audio silence using VAD, not used default def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor: if audio.numel() == 0: @@ -441,7 +451,7 @@ class VoxCPM2Model(nn.Module): return tokens, feats, t_mask, a_mask def generate(self, *args, **kwargs) -> torch.Tensor: - return next(self._generate(*args, streaming=False, **kwargs)) + return _next_and_close(self._generate(*args, streaming=False, **kwargs)) def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: return self._generate(*args, streaming=True, **kwargs) @@ -641,7 +651,7 @@ class VoxCPM2Model(nn.Module): yield decode_audio break else: - latent_pred, pred_audio_feat, context_len = next(inference_result) + latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -759,7 +769,7 @@ class VoxCPM2Model(nn.Module): return merged def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) + return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) def generate_with_prompt_cache_streaming( self, *args, **kwargs @@ -928,7 +938,7 @@ class VoxCPM2Model(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) break else: - latent_pred, pred_audio_feat, context_len = next(inference_result) + latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -951,7 +961,7 @@ class VoxCPM2Model(nn.Module): yield (decode_audio, target_text_token, pred_audio_feat) def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs)) + feat_pred, generated_feat, _ = _next_and_close(self._inference(*args, streaming=False, **kwargs)) return feat_pred, generated_feat def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: