31 lines
1009 B
Python
31 lines
1009 B
Python
import torch
|
|
import torch.nn as nn
|
|
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
|
from einops import rearrange
|
|
|
|
|
|
class VoxCPMLocEnc(nn.Module):
    """Local encoder that condenses each patch of feature vectors into one vector.

    Each ``[P, D]`` patch is linearly projected to ``config.hidden_size``,
    prefixed with a single learned special token, and passed through a
    ``MiniCPMModel`` called with ``is_causal=False``. The encoder output at
    the special-token position (index 0) is returned as the patch summary.
    """

    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        """Create the input projection, the special token and the encoder.

        Args:
            config: Configuration handed to ``MiniCPMModel``. Its
                ``hidden_size`` sets the projection width and the
                special-token width; its ``vocab_size`` must be 0.
            input_dim: Size of the last dimension of the input features.

        Raises:
            AssertionError: If ``config.vocab_size`` is not 0.
        """
        super().__init__()
        self.config = config
        # One learned token, shaped so it can be broadcast over (B, T) in forward().
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; the exception type and message are kept as they
        # were so existing callers observe the same failure.
        if config.vocab_size != 0:
            raise AssertionError("vocab_size must be 0 for local encoder")
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """Encode every patch of ``x`` into a single vector.

        Args:
            x: Tensor of shape ``[B, T, P, D]``, where ``D`` must match the
                ``input_dim`` the module was built with.

        Returns:
            Tensor of shape ``[B, T, C]`` holding, for each of the ``B * T``
            patches, the encoder output at the special-token position.
            ``C`` is the encoder's output width (presumably
            ``config.hidden_size`` -- confirm against ``MiniCPMModel``).
        """
        # Unpacking exactly four sizes also rejects inputs that are not 4-D.
        B, T, _, _ = x.shape

        x = self.in_proj(x)
        # Prepend the special token to every patch: [B, T, P, C] -> [B, T, P + 1, C].
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        # Fold batch and time together so each patch is an independent sequence.
        x = rearrange(x, "b t p c -> (b t) p c")
        outputs, _ = self.encoder(x, is_causal=False)
        # Position 0 is the special token prepended above.
        cls_output = outputs[:, 0, :]

        return rearrange(cls_output, "(b t) c -> b t c", b=B)
|