
synlp/PRISM


PRISM: Learning Shared Sentiment Prototypes for Adaptive Multimodal Sentiment Analysis

This is the official PyTorch implementation of the paper:

Learning Shared Sentiment Prototypes for Adaptive Multimodal Sentiment Analysis

Requirements

pip install -r requirements.txt

Data Preparation

Download the following datasets and update the dataPath in the corresponding config files:

| Dataset   | File             | Source |
|-----------|------------------|--------|
| CMU-MOSI  | aligned_50.pkl   | MMSA   |
| CMU-MOSEI | aligned_50.pkl   | MMSA   |
| CH-SIMS   | unaligned_39.pkl | MMSA   |
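The .pkl files are Python pickles distributed by MMSA. Before pointing a config at one, it can help to check what it actually contains. The sketch below walks the unpickled object and reports key paths with array shapes or lengths. It deliberately assumes nothing about the exact key layout. MMSA files are often organised by split (train/valid/test), but check your copy. `describe` and `inspect_pkl` are hypothetical helpers, not part of this repository. Unpickling the real files requires numpy to be installed.

```python
import pickle
from typing import Any, List


def describe(obj: Any, prefix: str = "", max_depth: int = 3) -> List[str]:
    """Return one line per leaf: dotted key path plus shape, length or type."""
    lines: List[str] = []
    if isinstance(obj, dict) and max_depth > 0:
        # Recurse into nested dicts, building a dotted path like "train.text".
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            lines.extend(describe(value, path, max_depth - 1))
        return lines
    shape = getattr(obj, "shape", None)  # numpy arrays expose .shape
    if shape is not None:
        lines.append(f"{prefix}: shape={tuple(shape)}")
    elif isinstance(obj, (list, tuple)):
        lines.append(f"{prefix}: {type(obj).__name__} of length {len(obj)}")
    else:
        lines.append(f"{prefix}: {type(obj).__name__}")
    return lines


def inspect_pkl(path: str) -> List[str]:
    # Only unpickle files from a trusted source: pickle can execute code.
    with open(path, "rb") as f:
        data = pickle.load(f)
    return describe(data)
```

For example, `for line in inspect_pkl("/path/to/aligned_50.pkl"): print(line)` prints one line per field, which makes it easy to confirm the file matches the dataset you configured.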

You also need pretrained BERT weights:

  • English datasets (MOSI/MOSEI): bert-base-uncased
  • Chinese dataset (SIMS): bert-base-chinese

Update bert_pretrained and dataPath in configs/*.yaml to your local paths.
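A wrong path usually only surfaces once training starts, so a quick check after editing a config can save time. The sketch assumes the YAML has already been loaded into a plain dict (for example with `yaml.safe_load`) and that `dataPath` and `bert_pretrained` are top-level keys. If they are nested in your config, adjust the lookups. It also assumes `bert_pretrained` is a local directory holding a downloaded BERT model. `check_paths` and `EXPECTED_BERT` are hypothetical names written for this README.

```python
import os
from typing import Dict, List

# BERT weights expected for each dataset, per the list above.
EXPECTED_BERT: Dict[str, str] = {
    "mosi": "bert-base-uncased",
    "mosei": "bert-base-uncased",
    "sims": "bert-base-chinese",
}


def check_paths(cfg: dict, dataset: str) -> List[str]:
    """Return human-readable problems found in a loaded config; [] means OK."""
    problems: List[str] = []
    data_path = cfg.get("dataPath")
    if not data_path or not os.path.exists(data_path):
        problems.append(f"dataPath not found: {data_path!r}")
    bert_path = cfg.get("bert_pretrained")
    # A locally downloaded Hugging Face model is a directory (config, vocab, weights).
    if not bert_path or not os.path.isdir(bert_path):
        problems.append(
            f"bert_pretrained not found: {bert_path!r} "
            f"(expected a local copy of {EXPECTED_BERT[dataset]})"
        )
    return problems
```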

Training

# Train on CMU-MOSI
python train.py --config_file configs/mosi.yaml --gpu_id 0

# Train on CMU-MOSEI
python train.py --config_file configs/mosei.yaml --gpu_id 0

# Train on CH-SIMS
python train.py --config_file configs/sims.yaml --gpu_id 0

Hyperparameters can be overridden from the command line:

python train.py --config_file configs/mosi.yaml --lr 2e-4 --bert_lr 5e-5 --batch_size 32
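This README does not show how train.py merges command-line flags with the YAML values. A common pattern, sketched below as an assumption and not as the repository's actual code, gives every overridable flag a default of `None` and copies only the flags the user actually passed over the values loaded from the config file. The flag names mirror the example command. `build_parser` and `merge_overrides` are hypothetical. The config is passed in as a dict so the sketch runs without PyYAML.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", required=True)
    parser.add_argument("--gpu_id", type=int, default=0)
    # Default None means "not given on the command line", so YAML values survive.
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--bert_lr", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    return parser


def merge_overrides(config: dict, argv: Optional[List[str]] = None) -> dict:
    """Return a copy of `config` with every explicitly passed CLI flag applied."""
    args = build_parser().parse_args(argv)
    merged = dict(config)  # leave the caller's dict untouched
    for key, value in vars(args).items():
        if value is not None:
            merged[key] = value
    return merged
```

Under this pattern, the example command keeps every YAML value except `lr`, `bert_lr` and `batch_size`.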

The best checkpoint (selected by validation MAE) is saved to ./best_cpk/ by default.
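Selecting the best checkpoint by validation MAE works like this. After each epoch, compute the mean absolute error between predicted and gold sentiment scores on the validation split. Save a new checkpoint only when that error is lower than the best seen so far. The sketch below covers just this bookkeeping in plain Python. `BestByMAE` is a hypothetical helper. In the actual training loop, the save step would presumably be a `torch.save` into ./best_cpk/, but that detail is not shown in this README.

```python
from typing import Sequence


def mean_absolute_error(preds: Sequence[float], labels: Sequence[float]) -> float:
    if len(preds) != len(labels) or not preds:
        raise ValueError("preds and labels must be non-empty and of equal length")
    return sum(abs(p - y) for p, y in zip(preds, labels)) / len(preds)


class BestByMAE:
    """Track the lowest validation MAE seen so far (lower is better)."""

    def __init__(self) -> None:
        self.best_mae = float("inf")
        self.best_epoch = -1

    def update(self, epoch: int, preds: Sequence[float], labels: Sequence[float]) -> bool:
        """Return True when this epoch beats the previous best, i.e. save now."""
        mae = mean_absolute_error(preds, labels)
        if mae < self.best_mae:
            self.best_mae, self.best_epoch = mae, epoch
            return True
        return False
```

In a training loop, you would call `tracker.update(epoch, val_preds, val_labels)` once per epoch and write the checkpoint only when it returns `True`. Ties do not overwrite the earlier checkpoint.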

Pretrained Checkpoints

Pretrained checkpoints are hosted on Hugging Face: DeepinChens/PRISM

Download and place them in ckpt/:

# Install huggingface_hub if needed
pip install huggingface_hub

# Download all checkpoints
huggingface-cli download DeepinChens/PRISM --local-dir ckpt/

Alternatively, download the .pth files manually from the DeepinChens/PRISM repository on Hugging Face and place them in ckpt/.
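The same download can also be done from Python with `huggingface_hub.snapshot_download`, followed by a check that .pth files actually landed in ckpt/. This README only names mosi_best.pth, so the check lists whatever .pth files are present rather than guessing the other file names. It also assumes the files sit at the root of the Hugging Face repository, as the `ckpt/mosi_best.pth` path below suggests. `download_checkpoints` and `list_checkpoints` are hypothetical helper names. The import is deferred so that the listing helper works even without huggingface_hub installed.

```python
import glob
import os
from typing import List


def download_checkpoints(local_dir: str = "ckpt") -> str:
    """Fetch every file from the Hugging Face repo into local_dir; return that path."""
    # Requires `pip install huggingface_hub`; imported lazily on purpose.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id="DeepinChens/PRISM", local_dir=local_dir)


def list_checkpoints(local_dir: str = "ckpt") -> List[str]:
    """Return the sorted names of .pth files found directly inside local_dir."""
    return sorted(os.path.basename(p) for p in glob.glob(os.path.join(local_dir, "*.pth")))
```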

Load a pretrained model for evaluation:

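import torch
import yaml

from core.utils import dict_to_namespace
from models.prism_model import build_model

# Read the dataset config and expose its keys as attributes (args.xxx)
with open('configs/mosi.yaml') as f:
    args = dict_to_namespace(yaml.safe_load(f))

model = build_model(args)
# weights_only=False permits unpickling non-tensor objects; use it only on trusted files
ckpt = torch.load('ckpt/mosi_best.pth', map_location='cpu', weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # evaluation mode: disables dropout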
