---
tags:
- pytorch_model_hub_mixin
---

## ⚙️ Usage

Our pretrained model is made available through the `rshf` and `transformers` packages for easy inference.
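If the package is not already installed, it can be fetched with pip (assuming `rshf` is published on PyPI under that name):

```shell
pip install rshf
```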

Load and initialize:
```python
from rshf.prom3e import ProM3E

model = ProM3E.from_pretrained("MVRL/ProM3E")
```

Inference:
```python
import torch

# Get precomputed embeddings from TaxaBind for image, sat, loc, env, text, and audio.
# Replace missing modalities with any placeholder vector.
# Stack the embeddings in the order: image, sat, loc, env, text, audio.
# Pass the stacked tensor through the model.

# Example:
image_embeds = torch.randn(2, 512)
sat_embeds = torch.randn(2, 512)
loc_embeds = torch.randn(2, 512)
env_embeds = torch.randn(2, 512)
text_embeds = torch.randn(2, 512)
audio_embeds = torch.randn(2, 512)

modalities = torch.stack((image_embeds, sat_embeds, loc_embeds, env_embeds, text_embeds, audio_embeds), dim=1)

modalities = torch.nn.functional.normalize(modalities, dim=-1)

# Indices of the modalities that are actually present (here: image and loc)
unmasked_modalities = [0, 2]

reconstructions, mu, log_var, hidden_repr = model.forward_inference(modalities, unmasked_modalities)
```
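Because the index list is tied to the fixed stacking order, it can be derived from the names of the modalities that are present. A minimal sketch (the `MODALITY_ORDER` list and `unmasked_indices` helper are illustrative conveniences, not part of `rshf`):

```python
# Fixed stacking order expected by the model (matches the comments above)
MODALITY_ORDER = ["image", "sat", "loc", "env", "text", "audio"]

def unmasked_indices(available):
    """Return the stack indices of the modalities that are actually present."""
    return [i for i, name in enumerate(MODALITY_ORDER) if name in available]

print(unmasked_indices({"image", "loc"}))  # [0, 2]
```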