fix: complete shared generator cleanup coverage
Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper. Made-with: Cursor
This commit is contained in:
+9
-1
@@ -8,6 +8,7 @@ from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
from .model.utils import next_and_close
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
@@ -171,7 +172,7 @@ class VoxCPM:
|
||||
)
|
||||
|
||||
def generate(self, *args, **kwargs) -> np.ndarray:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -292,8 +293,15 @@ class VoxCPM:
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
if streaming:
|
||||
try:
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
finally:
|
||||
generate_result.close()
|
||||
else:
|
||||
wav, _, _ = next_and_close(generate_result)
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
for tmp_path in temp_files:
|
||||
|
||||
@@ -3,6 +3,16 @@ import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
# does not get deferred to Python's GC/finalizer path.
|
||||
def next_and_close(gen):
|
||||
try:
|
||||
return next(gen)
|
||||
finally:
|
||||
gen.close()
|
||||
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
|
||||
|
||||
@@ -44,17 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
# does not get deferred to Python's GC/finalizer path.
|
||||
def _next_and_close(gen):
|
||||
try:
|
||||
return next(gen)
|
||||
finally:
|
||||
gen.close()
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -345,7 +335,7 @@ class VoxCPMModel(nn.Module):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return _next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -471,7 +461,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = _next_and_close(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -579,7 +569,7 @@ class VoxCPMModel(nn.Module):
|
||||
return merged_cache
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -698,7 +688,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = _next_and_close(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -721,7 +711,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return _next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
|
||||
@@ -45,17 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
# does not get deferred to Python's GC/finalizer path.
|
||||
def _next_and_close(gen):
|
||||
try:
|
||||
return next(gen)
|
||||
finally:
|
||||
gen.close()
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
||||
|
||||
|
||||
# A simple function to trim audio silence using VAD, not used default
|
||||
@@ -451,7 +441,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return tokens, feats, t_mask, a_mask
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return _next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -651,7 +641,7 @@ class VoxCPM2Model(nn.Module):
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -769,7 +759,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return merged
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -938,7 +928,7 @@ class VoxCPM2Model(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -961,7 +951,7 @@ class VoxCPM2Model(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
feat_pred, generated_feat, _ = _next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
return feat_pred, generated_feat
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
|
||||
Reference in New Issue
Block a user