refactor: centralize generator cleanup in model helpers
Factor repeated next-and-close patterns into a shared helper in both VoxCPM model variants so non-streaming inference cleans up generators consistently while keeping the issue reference close to the workaround. Made-with: Cursor
This commit is contained in:
@@ -47,6 +47,16 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
|||||||
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||||
|
|
||||||
|
|
||||||
|
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
|
||||||
|
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||||
|
# does not get deferred to Python's GC/finalizer path.
|
||||||
|
def _next_and_close(gen):
|
||||||
|
try:
|
||||||
|
return next(gen)
|
||||||
|
finally:
|
||||||
|
gen.close()
|
||||||
|
|
||||||
|
|
||||||
class VoxCPMEncoderConfig(BaseModel):
|
class VoxCPMEncoderConfig(BaseModel):
|
||||||
hidden_dim: int = 1024
|
hidden_dim: int = 1024
|
||||||
ffn_dim: int = 4096
|
ffn_dim: int = 4096
|
||||||
@@ -335,7 +345,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
return get_dtype(self.config.dtype)
|
return get_dtype(self.config.dtype)
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return _next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
@@ -461,7 +471,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = _next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -569,7 +579,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
return merged_cache
|
return merged_cache
|
||||||
|
|
||||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_with_prompt_cache_streaming(
|
def generate_with_prompt_cache_streaming(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
@@ -688,7 +698,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = _next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -711,7 +721,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
return _next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._inference(*args, streaming=True, **kwargs)
|
return self._inference(*args, streaming=True, **kwargs)
|
||||||
|
|||||||
@@ -48,6 +48,16 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
|||||||
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||||
|
|
||||||
|
|
||||||
|
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4234809321
|
||||||
|
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||||
|
# does not get deferred to Python's GC/finalizer path.
|
||||||
|
def _next_and_close(gen):
|
||||||
|
try:
|
||||||
|
return next(gen)
|
||||||
|
finally:
|
||||||
|
gen.close()
|
||||||
|
|
||||||
|
|
||||||
# A simple function to trim audio silence using VAD, not used default
|
# A simple function to trim audio silence using VAD, not used default
|
||||||
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||||
if audio.numel() == 0:
|
if audio.numel() == 0:
|
||||||
@@ -441,7 +451,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
return tokens, feats, t_mask, a_mask
|
return tokens, feats, t_mask, a_mask
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return _next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
@@ -641,7 +651,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -759,7 +769,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
return _next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_with_prompt_cache_streaming(
|
def generate_with_prompt_cache_streaming(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
@@ -928,7 +938,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = _next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -951,7 +961,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
|
feat_pred, generated_feat, _ = _next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||||
return feat_pred, generated_feat
|
return feat_pred, generated_feat
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
|
|||||||
Reference in New Issue
Block a user