print debug messages to stderr instead of stdout

This commit is contained in:
vytskalt
2026-01-09 20:05:52 +02:00
parent 6ecc00a5d3
commit f2e203d5e2
2 changed files with 14 additions and 12 deletions
+7 -6
View File
@@ -19,6 +19,7 @@ limitations under the License.
"""
import os
import sys
from typing import Tuple, Union, Generator, List, Optional
import torch
@@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config)
@@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except Exception as e:
print(f"Warning: torch.compile disabled - {e}")
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
def forward(
@@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
retry_badcase_times += 1
continue
else:
@@ -683,7 +684,7 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
retry_badcase_times += 1
continue
else:
@@ -868,10 +869,10 @@ class VoxCPMModel(nn.Module):
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}")
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
model_state_dict = load_file(safetensors_path)
elif os.path.exists(pytorch_model_path):
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
checkpoint = torch.load(
pytorch_model_path,
map_location="cpu",