update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -119,3 +119,67 @@ def get_dtype(dtype: str):
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def _has_mps() -> bool:
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
||||
"""
|
||||
Choose a runtime device automatically.
|
||||
|
||||
Preference order:
|
||||
- if the preferred device is available, use it
|
||||
- otherwise fall back to CUDA -> MPS -> CPU
|
||||
"""
|
||||
preferred = (preferred_device or "cuda").strip().lower()
|
||||
|
||||
if preferred.startswith("cuda") and torch.cuda.is_available():
|
||||
return preferred
|
||||
if preferred == "mps" and _has_mps():
|
||||
return "mps"
|
||||
if preferred == "cpu":
|
||||
return "cpu"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if _has_mps():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
|
||||
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
|
||||
"""
|
||||
Resolve the actual runtime device.
|
||||
|
||||
Semantics:
|
||||
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
|
||||
- otherwise: treat it as an explicit user choice and validate availability
|
||||
"""
|
||||
explicit = None if device is None else device.strip().lower()
|
||||
|
||||
if explicit is None or explicit == "auto":
|
||||
return auto_select_device(configured_device)
|
||||
|
||||
if explicit.startswith("cuda"):
|
||||
if not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
f"Requested device '{device}', but CUDA is not available. "
|
||||
"Use device='auto' for automatic fallback."
|
||||
)
|
||||
return explicit
|
||||
if explicit == "mps":
|
||||
if not _has_mps():
|
||||
raise ValueError(
|
||||
"Requested device 'mps', but MPS is not available. "
|
||||
"Use device='auto' for automatic fallback."
|
||||
)
|
||||
return "mps"
|
||||
if explicit == "cpu":
|
||||
return "cpu"
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
|
||||
"'cuda', or indexed CUDA devices like 'cuda:0'."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user