| import torch | |
| import requests | |
| from PIL import Image | |
| from io import BytesIO | |
| from torchvision import transforms | |
| from torch_geometric.data import Data, Batch | |
def fetch_image(url, *, timeout=2, size=224):
    """Download an image and return it as a normalized ``(3, size, size)`` tensor.

    This is best-effort. A non-string or blank URL, a network error, an HTTP
    error status, or an undecodable payload all yield an all-zero tensor of
    the same shape, so callers can always ``torch.stack`` the results.

    Args:
        url: Image URL. Non-string or whitespace-only values return the
            fallback tensor without touching the network.
        timeout: Passed to ``requests.get`` as its ``timeout`` argument.
        size: Output height and width in pixels (the image is resized to a
            square, so the aspect ratio is not preserved).

    Returns:
        A float tensor of shape ``(3, size, size)``. On success it is
        normalized with the ImageNet mean/std constants below; on failure it
        is all zeros.
    """
    default_tensor = torch.zeros((3, size, size))
    if not isinstance(url, str):
        return default_tensor
    # Use the stripped URL for the request as well as for the emptiness check:
    # requests only strips leading whitespace, so a trailing newline/space
    # (common in scraped profile data) would otherwise break the download.
    url = url.strip()
    if not url:
        return default_tensor
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # convert() loads the pixel data into a new image, so the source
        # image can be closed right away.
        with Image.open(BytesIO(response.content)) as raw:
            img = raw.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return transform(img)
    except Exception:
        # Deliberately broad: one broken avatar must not abort the whole batch.
        return default_tensor
def build_ego_graph(user_id, neighbor_dict, *, feature_dim=16):
    """Build a star-shaped ego graph centred on ``user_id``.

    The ego user is always node 0. Each distinct follower/followee gets the
    next free index in first-seen order, with followers scanned before
    followees. Edges are directed follower -> ego and ego -> followee. A
    neighbor that appears in both lists keeps a single index but gets one
    edge in each direction. A neighbor equal to ``user_id`` maps to node 0
    and so produces a self-loop.

    Args:
        user_id: Identifier of the ego user.
        neighbor_dict: Mapping with optional ``"follower"`` and
            ``"following"`` iterables of neighbor IDs, or a falsy value when
            no neighbor data exists. A ``None`` under either key is treated
            as an empty list.
        feature_dim: Width of the constant all-ones node feature matrix.

    Returns:
        A ``Data`` object with ``x`` of shape ``(num_nodes, feature_dim)`` and
        ``edge_index`` of shape ``(2, num_edges)``. An ego with no neighbors
        gets a single ``0 -> 0`` self-loop so ``edge_index`` is never empty.
    """
    neighbor_dict = neighbor_dict or {}
    # `or []` (rather than a .get default) also covers keys that are present
    # but hold None, which would otherwise raise TypeError when iterated.
    followers = neighbor_dict.get("follower") or []
    following = neighbor_dict.get("following") or []

    node_to_idx = {user_id: 0}

    def index_of(node):
        # First sighting assigns the next free index; later sightings reuse it.
        return node_to_idx.setdefault(node, len(node_to_idx))

    edges = [[index_of(f), 0] for f in followers]
    edges.extend([0, index_of(f)] for f in following)
    if not edges:
        edges = [[0, 0]]

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    # Constant features: every node looks identical; only structure differs.
    x = torch.ones((len(node_to_idx), feature_dim), dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
def _parse_label(raw_label):
    """Map a raw label to an int: "bot" -> 1, "human" -> 0, otherwise ``int(raw_label)``.

    The string comparison is case-insensitive and ignores surrounding
    whitespace. Values ``int()`` cannot convert (including ``None``) fall
    back to 0.
    """
    if isinstance(raw_label, str):
        token = raw_label.strip().lower()
        if token == "bot":
            return 1
        if token == "human":
            return 0
    try:
        return int(raw_label)
    except (TypeError, ValueError, OverflowError):
        # int(None) raises TypeError; NaN/inf floats raise ValueError/OverflowError.
        # NOTE(review): unparseable labels silently become 0 (human) — confirm
        # that is preferable to skipping or raising.
        return 0


def _join_tweets(tweets, limit):
    """Join up to ``limit`` string tweets with single spaces, skipping non-strings."""
    from itertools import islice

    if isinstance(tweets, str):
        # A bare string would otherwise be iterated character by character.
        tweets = [tweets]
    return " ".join(islice((t for t in tweets if isinstance(t, str)), limit))


def collate_twibot_batch(batch_data, tokenizer, *, max_tweets=5, max_length=128):
    """Collate raw user records into text, image, graph and label batches.

    Each record is read via ``.get`` for the keys ``"profile"``,
    ``"neighbor"``, ``"tweet"``, ``"ID"`` and ``"label"``. Missing or falsy
    values fall back to empty containers or defaults.

    Args:
        batch_data: Iterable of per-user dicts.
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True,
            truncation=True, max_length=..., return_tensors="pt")``,
            presumably a Hugging Face tokenizer — confirm against the caller.
        max_tweets: Maximum number of string tweets appended after the bio
            (must be non-negative).
        max_length: Truncation length forwarded to the tokenizer.

    Returns:
        A tuple ``(text_inputs, image_tensor, graph_batch, label_tensor)``:
        the tokenizer output; the stacked ``fetch_image`` outputs; a
        ``Batch`` of per-user ego graphs; and a float tensor of shape
        ``(batch_size, 1)``.
    """
    texts, images, graphs, labels = [], [], [], []
    for item in batch_data:
        profile = item.get("profile") or {}
        bio = profile.get("description") or ""
        tweets = _join_tweets(item.get("tweet") or [], max_tweets)
        texts.append(f"{bio} {tweets}")

        # NOTE(review): images are downloaded one at a time inside the collate
        # step; with large batches this network latency dominates.
        images.append(fetch_image(profile.get("profile_image_url") or ""))
        graphs.append(build_ego_graph(item.get("ID", ""), item.get("neighbor") or {}))
        labels.append(_parse_label(item.get("label", 0)))

    text_inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    image_tensor = torch.stack(images)
    graph_batch = Batch.from_data_list(graphs)
    # Float column vector, e.g. for a single-logit binary loss.
    label_tensor = torch.tensor(labels, dtype=torch.float).unsqueeze(1)
    return text_inputs, image_tensor, graph_batch, label_tensor