Files
VoxCPM/src/voxcpm/modules/locenc/local_encoder.py
T
zengguoyang 272b8ffbf6 init
2025-09-16 11:46:47 +08:00

31 lines
1009 B
Python

import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange
class VoxCPMLocEnc(nn.Module):
    """Local encoder that reduces each group of ``P`` input vectors to one vector.

    Every ``[P, D]`` group in the input is projected to ``config.hidden_size``,
    prefixed with a single learned special token, and run through a
    ``MiniCPMModel`` with ``is_causal=False``. The encoder output at the
    special-token position (index 0) is returned as the summary of the group.
    """

    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        """Build the input projection, the special token and the encoder.

        Args:
            config: Configuration passed to ``MiniCPMModel``. Its
                ``vocab_size`` must be 0 and its ``hidden_size`` sets the
                width of the projection and of the special token.
            input_dim: Size of the last dimension of the tensors given to
                ``forward``.

        Raises:
            AssertionError: If ``config.vocab_size`` is not 0.
        """
        super().__init__()
        self.config = config
        # Learned token prepended to every group; shaped [1, 1, 1, C] so that
        # it can be expanded over the two leading dimensions in ``forward``.
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
        # Raised explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        if config.vocab_size != 0:
            raise AssertionError("vocab_size must be 0 for local encoder")
        self.encoder = MiniCPMModel(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each ``[P, D]`` group of ``x`` into a single vector.

        Args:
            x: Tensor of shape ``[B, T, P, D]`` where ``D`` matches the
                ``input_dim`` given at construction.

        Returns:
            Tensor of shape ``[B, T, C]`` where ``C`` is the width of the
            encoder output (presumably ``config.hidden_size`` -- confirm
            against ``MiniCPMModel``).

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape [B, T, P, D], got {tuple(x.shape)}"
            )
        batch_size, num_steps = x.shape[:2]

        x = self.in_proj(x)
        # Broadcast the single learned token to one copy per (b, t) group and
        # put it in front of the P projected vectors: [B, T, P + 1, C].
        special_tokens = self.special_token.expand(batch_size, num_steps, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)

        # Fold B and T together so each group is an independent sequence.
        x = rearrange(x, "b t p c -> (b t) p c")
        # The second element returned by the encoder is not needed here.
        outputs, _ = self.encoder(x, is_causal=False)

        # Keep only the output at the special-token position.
        cls_output = outputs[:, 0, :]
        return rearrange(cls_output, "(b t) c -> b t c", b=batch_size)