File size: 3,046 Bytes
0205739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import functools
import logging
from io import BytesIO

import requests
import torch
from PIL import Image
from torch_geometric.data import Data, Batch
from torchvision import transforms

_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=8)
def _image_transform(size):
    """Return the resize -> tensor -> normalize pipeline for ``size`` x ``size`` output.

    Cached so the ``Compose`` object is built once per size instead of once per image.
    """
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std (presumably to match an
        # ImageNet-pretrained image encoder downstream -- confirm).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def fetch_image(url, *, timeout=2, size=224):
    """Download an image and return it as a normalized ``(3, size, size)`` float tensor.

    This is best-effort: a non-string or blank URL, a network/HTTP error, or an
    undecodable image all yield an all-zeros tensor of the same shape, so a list
    of results can always be stacked into one batch tensor.

    Args:
        url: Image URL. Non-strings and whitespace-only strings are treated as missing.
        timeout: Seconds passed to ``requests.get``.
        size: Output height and width in pixels.

    Returns:
        Float tensor of shape ``(3, size, size)``, either the RGB image after
        resizing and normalization, or zeros on any failure.
    """
    if not isinstance(url, str) or not url.strip():
        return torch.zeros((3, size, size))
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # convert() returns a new image, so the opened source can be released right away.
        with Image.open(BytesIO(response.content)) as raw:
            img = raw.convert("RGB")
        return _image_transform(size)(img)
    except Exception as exc:
        # Deliberately broad: any failure degrades to a blank image rather than
        # aborting the batch; logged so broken URLs are still discoverable.
        _logger.debug("fetch_image: using blank image for %r: %s", url, exc)
        return torch.zeros((3, size, size))

def build_ego_graph(user_id, neighbor_dict, *, feature_dim=16):
    """Build a star-shaped ego graph centred on ``user_id``.

    Node 0 is always ``user_id``. Each distinct neighbour ID gets the next free
    index in first-seen order, with followers first and then followed accounts.
    Edges run follower -> user (``[j, 0]``) and user -> followed account
    (``[0, j]``). An ID that appears in both lists shares one node.

    Args:
        user_id: ID of the centre user.
        neighbor_dict: Mapping with optional ``"follower"`` and ``"following"``
            iterables of neighbour IDs. May be ``None``/empty, and either entry
            may be ``None``.
        feature_dim: Width of the node-feature matrix.

    Returns:
        ``torch_geometric.data.Data`` with ``x`` of shape ``(num_nodes, feature_dim)``
        (constant all-ones features) and ``edge_index`` of shape ``(2, E)``.
        A user with no neighbours gets a single self-loop so that ``edge_index``
        is never empty.
    """
    neighbor_dict = neighbor_dict or {}
    # `or []` also covers keys that are present but explicitly None.
    followers = neighbor_dict.get("follower") or []
    following = neighbor_dict.get("following") or []

    node_to_idx = {user_id: 0}

    def _index(node):
        # First sighting assigns the next index; repeated IDs reuse their slot.
        return node_to_idx.setdefault(node, len(node_to_idx))

    edges = [[_index(f), 0] for f in followers]
    edges += [[0, _index(f)] for f in following]

    if not edges:
        edges = [[0, 0]]

    # (E, 2) list -> (2, E) COO layout expected by torch_geometric.
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    x = torch.ones((len(node_to_idx), feature_dim), dtype=torch.float)
    return Data(x=x, edge_index=edge_index)

def _parse_label(raw_label):
    """Convert a raw label to an int: 'bot' -> 1, 'human' -> 0, otherwise ``int(raw_label)``.

    String matching ignores case and surrounding whitespace. Values that cannot
    be converted (including ``None`` and non-finite floats) fall back to 0.
    """
    if isinstance(raw_label, str):
        key = raw_label.strip().lower()
        if key == "bot":
            return 1
        if key == "human":
            return 0
    try:
        return int(raw_label)
    except (TypeError, ValueError, OverflowError):
        # NOTE(review): silently mapping unparseable labels to 0 can hide data
        # errors; consider logging or rejecting such records upstream.
        return 0


def _user_text(profile, tweet_list, max_tweets):
    """Join the profile bio with the first ``max_tweets`` tweets into one string."""
    bio = profile.get("description") or ""
    # Skip None entries and coerce the rest to str so join() cannot raise.
    tweets = " ".join(str(t) for t in tweet_list[:max_tweets] if t is not None)
    return f"{bio} {tweets}"


def collate_twibot_batch(batch_data, tokenizer, *, max_tweets=5, max_length=128):
    """Collate raw user records into text, image, graph, and label batches.

    Each item is a dict with optional keys ``"profile"`` (``"description"``,
    ``"profile_image_url"``), ``"neighbor"``, ``"tweet"`` (list of tweet texts),
    ``"ID"`` and ``"label"``. Missing or ``None`` sections are treated as empty.

    Args:
        batch_data: Non-empty sequence of record dicts; an empty batch fails in
            ``torch.stack``.
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True,
            truncation=True, max_length=..., return_tensors="pt")``. Presumably
            a Hugging Face tokenizer; confirm against the caller.
        max_tweets: Number of leading tweets appended to the bio.
        max_length: Token truncation length passed to ``tokenizer``.

    Returns:
        Tuple ``(text_inputs, image_tensor, graph_batch, label_tensor)``:
        the tokenizer output, a ``(B, 3, H, W)`` stack from ``fetch_image``, a
        ``torch_geometric`` ``Batch`` of ego graphs, and ``(B, 1)`` float labels.
    """
    texts, images, graphs, labels = [], [], [], []
    for item in batch_data:
        profile = item.get("profile") or {}
        neighbor = item.get("neighbor") or {}
        tweet_list = item.get("tweet") or []

        texts.append(_user_text(profile, tweet_list, max_tweets))
        images.append(fetch_image(profile.get("profile_image_url") or ""))
        graphs.append(build_ego_graph(item.get("ID", ""), neighbor))
        labels.append(_parse_label(item.get("label", 0)))

    text_inputs = tokenizer(
        texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
    )
    image_tensor = torch.stack(images)
    graph_batch = Batch.from_data_list(graphs)
    # (B, 1) float targets, the shape a BCE-style loss expects (presumably
    # BCEWithLogitsLoss downstream; confirm).
    label_tensor = torch.tensor(labels, dtype=torch.float).unsqueeze(1)

    return text_inputs, image_tensor, graph_batch, label_tensor