# Extension Guide: Adding ViNet & PixelNet to Notebook

**Status:** Reference code for extending the refactored notebook  
**For:** PixelRec_Refactored.ipynb


## Part A: ViNet (Pre-extracted Visual Features) Extension

### When to Use:

- You have pre-extracted visual features (stored as numpy/pickle)
- Features come from ViT, ResNet50, or another encoder
- You want faster training than PixelNet

### Step 1: Load Pre-extracted Features

```python
# After data loading (Section 2)
import os

import numpy as np
import torch

def load_visual_features(feature_path, n_items):
    """
    Load pre-extracted visual features.
    Expected shape: [n_items, feature_dim]
    """
    try:
        features = np.load(feature_path)  # or pickle.load() for .pkl files
        assert features.shape[0] == n_items, \
            f"Feature count mismatch: {features.shape[0]} vs {n_items}"
        return torch.from_numpy(features).float()
    except (OSError, AssertionError):
        print(f"⚠️  Features not found (or invalid) at {feature_path}")
        print("   Creating dummy features for demo...")
        return torch.randn(n_items, 256)  # Random 256-dim features

# Load features (update path)
visual_features = load_visual_features(
    feature_path=os.path.join(project_root, 'features/vit_large_features.npy'),
    n_items=n_items
)
print(f"✓ Visual features shape: {visual_features.shape}")
```

### Step 2: Create ViNet Model

```python
class ViNetSequential(nn.Module):
    """
    ViNet: Sequential model with visual features.
    
    Combines:
    - Item ID embeddings
    - Pre-extracted visual features
    
    Architecture matches IDNet, but the input is the concatenation [id_emb ; vis_emb].
    """
    def __init__(self, n_items, vis_features, embedding_dim=64, 
                 max_seq_len=50, n_heads=2, n_layers=2):
        super().__init__()
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        
        # Visual features (fixed, not trainable)
        self.register_buffer('visual_features', vis_features)
        vis_feat_dim = vis_features.shape[-1]
        
        # ID embeddings
        self.item_embedding = nn.Embedding(n_items, embedding_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embedding_dim)
        
        # Project visual features to embedding space
        self.visual_proj = nn.Linear(vis_feat_dim, embedding_dim)
        
        # Transformer encoder
        combined_dim = embedding_dim + embedding_dim  # id_emb + visual_emb
        self.transformer = SimpleTransformerEncoder(
            hidden_size=combined_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=0.1
        )
        
        self.layer_norm = nn.LayerNorm(combined_dim)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, item_seq, neg_item_seq, mask):
        """
        Args:
            item_seq: [batch, seq_len+1] - positive items
            neg_item_seq: [batch, seq_len+1] - negative items
            mask: [batch, seq_len] - valid positions
        """
        # note: this seq_len includes the target position (i.e., history + 1)
        batch_size, seq_len = item_seq.shape
        
        # Embed positive and negative items
        pos_id_emb = self.item_embedding(item_seq)  # [batch, seq_len+1, dim]
        neg_id_emb = self.item_embedding(neg_item_seq)
        
        # Get visual embeddings
        pos_vis_emb = self.visual_proj(self.visual_features[item_seq])  # [batch, seq_len+1, dim]
        neg_vis_emb = self.visual_proj(self.visual_features[neg_item_seq])
        
        # Concatenate ID + visual for sequences
        seq_id_emb = pos_id_emb[:, :-1, :]  # [batch, seq_len, dim]
        seq_vis_emb = pos_vis_emb[:, :-1, :]  # [batch, seq_len, dim]
        # Add position embeddings (to the ID half only, scaled by 1/2)
        pos_ids = torch.arange(seq_len - 1, device=item_seq.device).unsqueeze(0).expand(batch_size, -1)
        seq_id_emb = seq_id_emb + self.position_embedding(pos_ids) / 2
        seq_emb = torch.cat([seq_id_emb, seq_vis_emb], dim=-1)  # [batch, seq_len, 2*dim]

        seq_emb = self.layer_norm(seq_emb)
        seq_emb = self.dropout(seq_emb)

        # Causal mask
        attn_mask = torch.triu(
            torch.ones(seq_len - 1, seq_len - 1, device=item_seq.device) * float('-inf'),
            diagonal=1
        )

        # Transformer
        output = self.transformer(seq_emb, attn_mask=attn_mask)  # [batch, seq_len, 2*dim]

        # Target embeddings (shift by one: position t predicts item t+1)
        target_pos_emb = torch.cat([pos_id_emb[:, 1:, :], pos_vis_emb[:, 1:, :]], dim=-1)  # [batch, seq_len, 2*dim]
        target_neg_emb = torch.cat([neg_id_emb[:, 1:, :], neg_vis_emb[:, 1:, :]], dim=-1)

        # BPR loss over valid positions only
        pos_score = (output * target_pos_emb).sum(dim=-1)  # [batch, seq_len]
        neg_score = (output * target_neg_emb).sum(dim=-1)  # [batch, seq_len]
        loss = -torch.log(torch.sigmoid(pos_score - neg_score) + 1e-8)
        loss = (loss * mask).sum(dim=-1).mean()

        return loss

    @torch.no_grad()
    def predict(self, item_seq):
        """
        Predict scores for all items. Call model.eval() first so dropout is disabled.
        """
        device = next(self.parameters()).device
        seq_len = len(item_seq)

        # Left-pad the sequence to max_seq_len
        padded_seq = torch.zeros(1, self.max_seq_len, dtype=torch.long, device=device)
        padded_seq[0, -seq_len:] = torch.tensor(item_seq, dtype=torch.long, device=device)

        # Embed IDs and visual features
        id_emb = self.item_embedding(padded_seq)  # [1, max_seq_len, dim]
        vis_emb = self.visual_proj(self.visual_features[padded_seq.squeeze(0)]).unsqueeze(0)  # [1, max_seq_len, dim]

        # Position embeddings on the ID half only (matching forward)
        pos_ids = torch.arange(self.max_seq_len, device=device).unsqueeze(0)
        id_emb = id_emb + self.position_embedding(pos_ids) / 2
        seq_emb = torch.cat([id_emb, vis_emb], dim=-1)  # [1, max_seq_len, 2*dim]

        seq_emb = self.layer_norm(seq_emb)
        seq_emb = self.dropout(seq_emb)

        # Causal mask
        attn_mask = torch.triu(
            torch.ones(self.max_seq_len, self.max_seq_len, device=device) * float('-inf'),
            diagonal=1
        )

        # Transform
        output = self.transformer(seq_emb, attn_mask=attn_mask)  # [1, max_seq_len, 2*dim]
        user_repr = output[0, -1, :]  # [2*dim]

        # Score all items: ID part + visual part
        all_id_embs = self.item_embedding.weight               # [n_items, dim]
        all_vis_embs = self.visual_proj(self.visual_features)  # [n_items, dim]

        scores = torch.matmul(user_repr[:self.embedding_dim], all_id_embs.t()) + \
                 torch.matmul(user_repr[self.embedding_dim:], all_vis_embs.t())  # [n_items]

        return scores.cpu().numpy()

print("✓ ViNetSequential defined")
```
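The objective in `forward()` is the pairwise BPR loss, `-log σ(pos_score − neg_score)`, averaged over the positions marked valid by `mask`. As a quick smoke test, here is a minimal sketch assuming the notebook's `SimpleTransformerEncoder` is already defined (all sizes below are made up):

```python
# Smoke test with dummy inputs; shapes follow the forward() docstring.
model = ViNetSequential(n_items=100, vis_features=torch.randn(100, 256))
item_seq = torch.randint(1, 100, (4, 11))      # [batch, seq_len+1]
neg_item_seq = torch.randint(1, 100, (4, 11))  # [batch, seq_len+1]
mask = torch.ones(4, 10)                       # [batch, seq_len]
print(model(item_seq, neg_item_seq, mask).item())  # should print a finite loss
```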

### Step 3: Train ViNet (reuses the IDNet training loop)

```python
# Train model
print("🎯 Training ViNet Model...\n")

vinet_model = ViNetSequential(
    n_items=n_items,
    vis_features=visual_features,
    embedding_dim=64,
    max_seq_len=50,
    n_heads=2,
    n_layers=2
)
vinet_model = vinet_model.to(device)

optimizer = torch.optim.Adam(vinet_model.parameters(), lr=0.001, weight_decay=0.0)
epochs = 10
losses = []

for epoch in range(epochs):
    loss = train_epoch(vinet_model, train_data, user_sequences, n_items, optimizer, device=device)
    losses.append(loss)
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

print("\n✓ ViNet training completed")

# Evaluate (same as IDNet)
vinet_valid_results = Evaluator.evaluate(vinet_model, valid_data, user_sequences, device=device)
vinet_test_results = Evaluator.evaluate(vinet_model, test_data, user_sequences, device=device)
```
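Optionally, plot the recorded `losses` to check convergence (assumes matplotlib is available in the notebook):

```python
import matplotlib.pyplot as plt

plt.plot(range(1, epochs + 1), losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('BPR loss')
plt.title('ViNet training loss')
plt.show()
```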

### Step 4: Compare with IDNet

```python
# Comparison
comparison_results = pd.DataFrame({
    'Metric': list(test_results.keys()),
    'IDNet': list(test_results.values()),
    'ViNet': [vinet_test_results.get(m, 0) for m in test_results.keys()]
})

print("\n📊 IDNet vs ViNet Comparison (Test Set):")
print(comparison_results.to_string(index=False))

# Relative improvement of ViNet over IDNet, averaged across metrics
idnet_mean = np.mean(list(test_results.values()))
vinet_mean = np.mean(list(vinet_test_results.values()))
print(f"\n🎯 Improvement: {(vinet_mean / idnet_mean - 1) * 100:.1f}%")
```

---

## Part B: PixelNet (End-to-End Image Encoding) Extension

### Prerequisites:
1. **Real images** in a folder or LMDB database
2. **Image encoder** (ViT, Swin, etc.)
3. **generate_lmdb.py** already run

### Step 1: Simple Image Encoder

```python
import torch.nn as nn
from torchvision.models import (
    vit_b_16, ViT_B_16_Weights,
    resnet50, ResNet50_Weights,
)

def create_image_encoder(encoder_type='vit', pretrained=True):
    """
    Create an image encoder.
    Adapted from REC/model/load.py.
    """
    if encoder_type == 'vit':
        model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT if pretrained else None)
        model.heads = nn.Identity()  # drop classification head; output: [batch, 768]
        input_size = 224
        output_dim = 768

    elif encoder_type == 'resnet50':
        model = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        model.fc = nn.Identity()  # drop classification head; output: [batch, 2048]
        input_size = 224
        output_dim = 2048

    else:
        raise ValueError(f"Unknown encoder: {encoder_type}")

    return model, input_size, output_dim

image_encoder, img_size, vis_dim = create_image_encoder('vit', pretrained=True)
print(f"✓ Image encoder: ViT, output dim: {vis_dim}")
```

### Step 2: Create Dummy Image Dataset

```python
# For demo (without LMDB)
import torchvision.transforms as transforms

class DummyImageDataset:
    """Dummy dataset - generates random images"""
    def __init__(self, n_items, img_size=224):
        self.n_items = n_items
        self.img_size = img_size
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __getitem__(self, item_id):
        # Generate a random, already-normalized image tensor
        dummy_img = torch.randint(0, 256, (3, self.img_size, self.img_size), dtype=torch.uint8)
        dummy_img = dummy_img.float() / 255.0
        dummy_img = (dummy_img - 0.5) / 0.5
        return dummy_img

image_dataset = DummyImageDataset(n_items=n_items)
```
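An optional sanity check before wiring the encoder into a model: push one dummy image through it and confirm the feature dimension matches `vis_dim`:

```python
# One dummy image through the encoder (still on CPU at this point).
img = image_dataset[0]  # [3, 224, 224]
with torch.no_grad():
    feat = image_encoder(img.unsqueeze(0))  # [1, 768] for the ViT-B/16 above
print(img.shape, feat.shape)
```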
### Step 3: PixelNet Model

```python
class PixelNetSequential(nn.Module):
    """
    PixelNet: End-to-end sequential model with an image encoder.

    Similar to MOSASRec (REC/model/PixelNet/mosasrec.py) but simplified.
    """
    def __init__(self, n_items, image_encoder, image_encoder_dim, embedding_dim=64,
                 max_seq_len=50, n_heads=2, n_layers=2):
        super().__init__()
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.image_encoder = image_encoder

        # Project image features to embedding space
        self.image_proj = nn.Linear(image_encoder_dim, embedding_dim)

        # ID embeddings (fallback when no images are provided)
        self.item_embedding = nn.Embedding(n_items, embedding_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embedding_dim)

        # Transformer
        self.transformer = SimpleTransformerEncoder(
            hidden_size=embedding_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=0.1
        )

        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, item_seq, neg_item_seq, mask, images=None):
        """
        Args:
            item_seq: [batch, seq_len+1] - positive item IDs
            neg_item_seq: [batch, seq_len+1] - negative item IDs
            mask: [batch, seq_len] - valid positions
            images: [batch*2, seq_len+1, 3, H, W] - actual images, positives
                stacked before negatives along dim 0 (optional)
        """
        # note: this seq_len includes the target position (i.e., history + 1)
        batch_size, seq_len = item_seq.shape

        if images is not None:
            # Encode images: [batch*2, seq_len+1, 3, H, W] -> [batch*2*(seq_len+1), vis_dim]
            n_seq, n_pos, c, h, w = images.shape  # n_seq = 2*batch (positives first, then negatives)
            images_flat = images.view(n_seq * n_pos, c, h, w)

            with torch.no_grad():  # encoder is usually frozen
                img_feats = self.image_encoder(images_flat)  # [batch*2*(seq_len+1), vis_dim]

            img_feats = self.image_proj(img_feats)  # [batch*2*(seq_len+1), dim]
            pos_emb = img_feats[:batch_size * seq_len].view(batch_size, seq_len, -1)
            neg_emb = img_feats[batch_size * seq_len:].view(batch_size, seq_len, -1)
        else:
            # Fall back to ID embeddings if no images
            pos_emb = self.item_embedding(item_seq)
            neg_emb = self.item_embedding(neg_item_seq)

        # Sequence embedding (exclude target)
        seq_emb = pos_emb[:, :-1, :]  # [batch, seq_len, dim]

        # Add position embeddings
        pos_ids = torch.arange(seq_len - 1, device=item_seq.device).unsqueeze(0).expand(batch_size, -1)
        seq_emb = seq_emb + self.position_embedding(pos_ids)
        seq_emb = self.layer_norm(seq_emb)
        seq_emb = self.dropout(seq_emb)

        # Causal mask
        attn_mask = torch.triu(
            torch.ones(seq_len - 1, seq_len - 1, device=item_seq.device) * float('-inf'),
            diagonal=1
        )

        # Transform
        output = self.transformer(seq_emb, attn_mask=attn_mask)  # [batch, seq_len, dim]

        # Target embeddings (shift by one: position t predicts item t+1)
        target_pos_emb = pos_emb[:, 1:, :]  # [batch, seq_len, dim]
        target_neg_emb = neg_emb[:, 1:, :]  # [batch, seq_len, dim]

        # BPR loss over valid positions only
        pos_score = (output * target_pos_emb).sum(dim=-1)  # [batch, seq_len]
        neg_score = (output * target_neg_emb).sum(dim=-1)  # [batch, seq_len]
        loss = -torch.log(torch.sigmoid(pos_score - neg_score) + 1e-8)
        loss = (loss * mask).sum(dim=-1).mean()

        return loss

    @torch.no_grad()
    def predict(self, item_seq, images=None):
        """
        Predict scores for all items. Call model.eval() first so dropout is disabled.
        """
        device = next(self.parameters()).device
        seq_len = len(item_seq)

        # Left-pad item IDs to max_seq_len
        padded_seq = torch.zeros(1, self.max_seq_len, dtype=torch.long, device=device)
        padded_seq[0, -seq_len:] = torch.tensor(item_seq, dtype=torch.long, device=device)

        if images is not None:
            # Encode the sequence images: [seq_len, 3, H, W] (assumes 224x224 inputs)
            seq_imgs_flat = images.view(seq_len, 3, 224, 224)
            seq_feats = self.image_encoder(seq_imgs_flat)  # [seq_len, vis_dim]
            seq_emb = self.image_proj(seq_feats).unsqueeze(0)  # [1, seq_len, dim]
        else:
            seq_emb = self.item_embedding(padded_seq)  # [1, max_seq_len, dim]

        # Left-pad embeddings to max_seq_len
        if seq_emb.shape[1] < self.max_seq_len:
            pad_size = self.max_seq_len - seq_emb.shape[1]
            seq_emb = torch.cat(
                [torch.zeros(1, pad_size, seq_emb.shape[-1], device=seq_emb.device), seq_emb],
                dim=1
            )

        # Position
        pos_ids = torch.arange(self.max_seq_len, device=device).unsqueeze(0)
        seq_emb = seq_emb + self.position_embedding(pos_ids)
        seq_emb = self.layer_norm(seq_emb)
        seq_emb = self.dropout(seq_emb)

        # Causal mask
        attn_mask = torch.triu(
            torch.ones(self.max_seq_len, self.max_seq_len, device=device) * float('-inf'),
            diagonal=1
        )

        # Transform
        output = self.transformer(seq_emb, attn_mask=attn_mask)  # [1, max_seq_len, dim]
        user_repr = output[0, -1, :]  # [dim]

        # Score all items. Encoding every item image here would be expensive, so
        # this demo scores against the ID embeddings; in production, precompute
        # image embeddings for all items (see "Notes for Production Use" below).
        all_item_reprs = self.item_embedding.weight  # [n_items, dim]

        scores = torch.matmul(user_repr, all_item_reprs.t())  # [n_items]

        return scores.cpu().numpy()

print("✓ PixelNetSequential defined")
```
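As with ViNet, a quick smoke test of the `images=None` fallback path (a sketch assuming `SimpleTransformerEncoder` and the Step 1 encoder are defined; the image branch is not exercised here):

```python
# Smoke test (images=None → ID-embedding fallback path).
tmp_model = PixelNetSequential(n_items=100, image_encoder=image_encoder, image_encoder_dim=vis_dim)
item_seq = torch.randint(1, 100, (4, 11))      # [batch, seq_len+1]
neg_item_seq = torch.randint(1, 100, (4, 11))
mask = torch.ones(4, 10)
print(tmp_model(item_seq, neg_item_seq, mask).item())
```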
### Step 4: Train PixelNet (with dummy images)

```python
# Train model
print("🎯 Training PixelNet Model...\n")

pixel_model = PixelNetSequential(
    n_items=n_items,
    image_encoder=image_encoder,
    image_encoder_dim=vis_dim,
    embedding_dim=64,
    max_seq_len=50,
    n_heads=2,
    n_layers=2
)
pixel_model = pixel_model.to(device)

# Freeze the image encoder
for param in pixel_model.image_encoder.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(
    [p for p in pixel_model.parameters() if p.requires_grad],
    lr=0.001
)

epochs = 10
losses = []

for epoch in range(epochs):
    # Reuses the notebook's training loop; it passes no images, so this demo
    # trains the ID-embedding fallback path. Adapt the loop to feed images
    # for true end-to-end training.
    loss = train_epoch(pixel_model, train_data, user_sequences, n_items, optimizer, device=device)
    losses.append(loss)
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

print("\n✓ PixelNet training completed")
```

### Step 5: Three-Way Comparison

```python
# Evaluate all models
pixel_test_results = Evaluator.evaluate(pixel_model, test_data, user_sequences, device=device)

# Compare
comparison = pd.DataFrame({
    'Metric': list(test_results.keys()),
    'IDNet': list(test_results.values()),
    'ViNet': [vinet_test_results.get(m, 0) for m in test_results.keys()],
    'PixelNet': [pixel_test_results.get(m, 0) for m in test_results.keys()]
})

print("\n📊 Three-Way Comparison (Test Set):")
print(comparison.to_string(index=False))

# Calculate improvements
idnet_baseline = np.mean(list(test_results.values()))
vinet_improvement = (np.mean(list(vinet_test_results.values())) / idnet_baseline - 1) * 100
pixelnet_improvement = (np.mean(list(pixel_test_results.values())) / idnet_baseline - 1) * 100

print("\n📈 Improvements over IDNet:")
print(f"  ViNet:    +{vinet_improvement:.1f}%")
print(f"  PixelNet: +{pixelnet_improvement:.1f}%")
```

---

## Notes for Production Use

### Real Images (LMDB)
When you have real images:

```python
import io

import lmdb
import torchvision.transforms as transforms
from PIL import Image

class RealImageDataset:
    def __init__(self, lmdb_path, img_size=224):
        self.env = lmdb.open(lmdb_path, readonly=True)
        self.txn = self.env.begin()
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __getitem__(self, item_id):
        img_bytes = self.txn.get(f'{item_id}'.encode())
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        return self.transform(img)
```

### Precompute All Item Embeddings (PixelNet)
For inference efficiency:

```python
@torch.no_grad()
def compute_all_item_embeddings(model, image_dataset, n_items, batch_size=32):
    """Precompute image embeddings for all items (uses the notebook-level `device`)."""
    all_embs = []
    for i in range(0, n_items, batch_size):
        batch_ids = list(range(i, min(i + batch_size, n_items)))
        batch_imgs = torch.stack([image_dataset[iid] for iid in batch_ids])
        batch_imgs = batch_imgs.to(device)
        embs = model.image_encoder(batch_imgs)  # [batch, vis_dim]
        embs = model.image_proj(embs)           # [batch, embedding_dim]
        all_embs.append(embs.cpu())

    return torch.cat(all_embs, dim=0)  # [n_items, dim]
```
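With that table precomputed, `predict()` can score against real image embeddings instead of the ID-embedding demo fallback. A minimal sketch of the wiring (the commented edit is a suggestion, not part of the notebook):

```python
# Precompute once, reuse for every prediction.
all_item_embs = compute_all_item_embeddings(pixel_model, image_dataset, n_items).to(device)

# Then, inside PixelNetSequential.predict(), replace the ID-embedding fallback with:
#   scores = torch.matmul(user_repr, all_item_embs.t())
```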
---

**Created:** 2026-04-08  
**Target:** PixelRec_Refactored.ipynb Extension  
**Status:** Reference code (copy-paste into notebook cells)