init
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class VoxCPMLocEnc(nn.Module):
|
||||
def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
|
||||
self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
|
||||
self.encoder = MiniCPMModel(config)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: [B, T, P, D]
|
||||
"""
|
||||
B, T, P, D = x.shape
|
||||
|
||||
x = self.in_proj(x)
|
||||
special_tokens = self.special_token.expand(B, T, 1, -1)
|
||||
x = torch.cat([special_tokens, x], dim=2)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
outputs, _ = self.encoder(x, is_causal=False)
|
||||
cls_output = outputs[:, 0, :]
|
||||
|
||||
return rearrange(cls_output, "(b t) c -> b t c", b=B)
|
||||
Reference in New Issue
Block a user