File size: 2,597 Bytes
0205739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
from transformers import AutoModel
from torchvision import models
from torch_geometric.nn import GCNConv

class TextEncoder(nn.Module):
    """Pretrained Hugging Face transformer that returns the first-token embedding.

    The embedding of position 0 (the [CLS]-style token) is used as the
    sequence-level representation.
    """

    def __init__(self, model_name="distilbert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        # Expose the hidden size so downstream fusion layers can size themselves.
        self.out_dim = self.transformer.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Encode tokens; returns a (batch, hidden_size) tensor."""
        encoded = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # Slice out the first token's hidden state for every sequence in the batch.
        cls_embedding = encoded.last_hidden_state[:, 0, :]
        return cls_embedding

class VisionEncoder(nn.Module):
    """ResNet-18 backbone (ImageNet weights) with the classification head removed.

    Produces a flat (batch, 512) feature vector per image.
    """

    def __init__(self):
        super().__init__()
        # weights=... is the modern, non-deprecated way to load a pretrained ResNet.
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Keep every stage except the final fc layer; the global average pool
        # at the end yields a (batch, 512, 1, 1) activation.
        trunk = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*trunk)
        self.out_dim = 512  # ResNet-18's penultimate feature width

    def forward(self, images):
        """Extract features; returns a (batch, 512) tensor."""
        features = self.feature_extractor(images)
        # Collapse the trailing singleton spatial dims into a flat vector.
        return features.view(features.size(0), -1)

class GraphEncoder(nn.Module):
    """Two-layer GCN that returns one embedding per graph in a PyG batch.

    The representation of each graph's first node (assumed to be the
    "central" node by construction of the dataset) is used as the
    graph-level embedding.
    """

    def __init__(self, in_channels=16, hidden_channels=64, out_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.relu = nn.ReLU()
        # Published so the fusion head can size its input layer.
        self.out_dim = out_channels

    def forward(self, batch_data):
        """Encode a batched graph; returns a (num_graphs, out_channels) tensor."""
        node_features = batch_data.x
        edges = batch_data.edge_index
        # ReLU follows only the first convolution; the second layer's output
        # is left linear, matching the original design.
        hidden = self.relu(self.conv1(node_features, edges))
        node_embeddings = self.conv2(hidden, edges)
        # batch_data.ptr[:-1] holds the offset of each graph's first node in
        # the concatenated node tensor — gather those rows as graph embeddings.
        first_node_per_graph = batch_data.ptr[:-1]
        return node_embeddings[first_node_per_graph]

class MultiModalDetector(nn.Module):
    """Fuses text, vision and graph embeddings and scores them with an MLP.

    The forward pass concatenates the three modality embeddings and produces
    a single logit per sample (pair with ``BCEWithLogitsLoss`` or a sigmoid).

    Args:
        text_dim: retained for backward compatibility; the effective text
            width is now read from ``TextEncoder.out_dim``.
        vision_dim: retained for backward compatibility; superseded by
            ``VisionEncoder.out_dim``.
        graph_dim: output width requested from the ``GraphEncoder``.
    """

    def __init__(self, text_dim=768, vision_dim=512, graph_dim=128):
        super().__init__()
        self.text_encoder = TextEncoder()
        self.vision_encoder = VisionEncoder()
        self.graph_encoder = GraphEncoder(out_channels=graph_dim)

        # Fix: derive the fused width from the encoders themselves rather than
        # trusting the constructor arguments. Previously a caller passing a
        # text_dim/vision_dim that disagreed with the actual encoder outputs
        # (e.g. after swapping the transformer backbone) built an nn.Linear
        # whose input size was wrong and only failed deep inside forward().
        fused_dim = (
            self.text_encoder.out_dim
            + self.vision_encoder.out_dim
            + self.graph_encoder.out_dim
        )

        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, text_inputs, images, graph_batch):
        """Score a batch of (text, image, graph) triples.

        Args:
            text_inputs: mapping with ``input_ids`` and ``attention_mask`` tensors.
            images: image tensor accepted by the ResNet backbone.
            graph_batch: PyG batch object with ``x``, ``edge_index`` and ``ptr``.

        Returns:
            A (batch, 1) tensor of raw logits.
        """
        text_emb = self.text_encoder(text_inputs['input_ids'], text_inputs['attention_mask'])
        vision_emb = self.vision_encoder(images)
        graph_emb = self.graph_encoder(graph_batch)
        # Concatenate along the feature axis; batch dims must agree.
        fused_vector = torch.cat([text_emb, vision_emb, graph_emb], dim=1)
        return self.mlp(fused_vector)