Feature Extraction
sentence-transformers
Safetensors
Transformers
English
mistraldual
sentence-similarity
custom_code
Instructions to use GeoGPT-Research-Project/GeoEmbedding with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use GeoGPT-Research-Project/GeoEmbedding with sentence-transformers:
from sentence_transformers import SentenceTransformer model = SentenceTransformer("GeoGPT-Research-Project/GeoEmbedding", trust_remote_code=True) sentences = [ "The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium." ] embeddings = model.encode(sentences) similarities = model.similarity(embeddings, embeddings) print(similarities.shape) # [3, 3] - Transformers
How to use GeoGPT-Research-Project/GeoEmbedding with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="GeoGPT-Research-Project/GeoEmbedding", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("GeoGPT-Research-Project/GeoEmbedding", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload 16 files
Browse files- .gitattributes +1 -0
- 1_Pooling/config.json +10 -0
- README.md +44 -0
- config.json +31 -0
- config_sentence_transformers.json +10 -0
- configuration_mistral_dual.py +13 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +297 -0
- modeling_mistral_dual.py +176 -0
- modules.json +14 -0
- sentence_bert_config.json +4 -0
- special_tokens_map.json +35 -0
- tokenizer.json +3 -0
- tokenizer_config.json +55 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
1_Pooling/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"word_embedding_dimension": 4096,
|
| 3 |
+
"pooling_mode_cls_token": false,
|
| 4 |
+
"pooling_mode_mean_tokens": false,
|
| 5 |
+
"pooling_mode_max_tokens": false,
|
| 6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
| 7 |
+
"pooling_mode_weightedmean_tokens": false,
|
| 8 |
+
"pooling_mode_lasttoken": true,
|
| 9 |
+
"include_prompt": true
|
| 10 |
+
}
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Card for GeoEmbedding
|
| 2 |
+
The GeoEmbedding model is a geoscience-specific text embedding model built upon a high-performance large language model and fine-tuned on both general-purpose and in-domain geoscientific datasets. It produces accurate, context-aware vector representations of geoscientific texts, forming the backbone of vector-based retrieval in the RAG pipeline.
|
| 3 |
+
|
| 4 |
+
## Quick Start
|
| 5 |
+
To load the GeoEmbedding model with Transformer, use the following snippet:
|
| 6 |
+
```python
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
task_description = 'Given a web search query, retrieve relevant passages that answer the query'
|
| 11 |
+
def get_detailed_instruct(task_description: str, query: str) -> str:
|
| 12 |
+
return f'Instruct: {task_description}\nQuery: {query}'
|
| 13 |
+
|
| 14 |
+
model_name_or_path = 'GeoGPT/GeoEmbedding'
|
| 15 |
+
|
| 16 |
+
model = SentenceTransformer(model_name_or_path, device="cuda", trust_remote_code=True)
|
| 17 |
+
|
| 18 |
+
queries = [
|
| 19 |
+
"What is the main cause of earthquakes?",
|
| 20 |
+
"How do sedimentary rocks form?",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
passages = [
|
| 24 |
+
"Earthquakes occur due to the sudden release of energy in the Earth's crust, often caused by tectonic plate movements along fault lines.",
|
| 25 |
+
"Sedimentary rocks form through the deposition and compaction of mineral and organic particles over time, typically in water bodies.",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
queries = [get_detailed_instruct(task_description, query) for query in queries]
|
| 29 |
+
|
| 30 |
+
q_vecs = model.encode(queries, normalize_embeddings=True)
|
| 31 |
+
p_vecs = model.encode(passages, normalize_embeddings=True)
|
| 32 |
+
|
| 33 |
+
print(np.dot(q_vecs, p_vecs.T))
|
| 34 |
+
#[[0.6369 0.2092 ]
|
| 35 |
+
# [0.2499 0.8411 ]]
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## License and Uses
|
| 39 |
+
GeoEmbedding is liscensed under Apache License 2.0, and is trained on the foundation of [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1), which is also licensed under the Apache License 2.0. It is your responsibility to ensure that your use of GeoEmbedding adheres to the terms of both the GeoEmbedding model and its upstream dependency, [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1).
|
| 40 |
+
|
| 41 |
+
The model is not intended for use in any manner that violates applicable laws or regulations, nor for any activities prohibited by the license agreement. Additionally, it should not be used in languages other than those explicitly supported, as outlined in this model card.
|
| 42 |
+
|
| 43 |
+
## Limitations
|
| 44 |
+
GeoEmbedding is trained on English datasets, and performance may be suboptimal for other languages.
|
config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MistralDualModel"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_mistral_dual.MistralDualConfig",
|
| 8 |
+
"AutoModel": "modeling_mistral_dual.MistralDualModel"
|
| 9 |
+
},
|
| 10 |
+
"bos_token_id": 1,
|
| 11 |
+
"eos_token_id": 2,
|
| 12 |
+
"head_dim": 128,
|
| 13 |
+
"hidden_act": "silu",
|
| 14 |
+
"hidden_size": 4096,
|
| 15 |
+
"initializer_range": 0.02,
|
| 16 |
+
"intermediate_size": 14336,
|
| 17 |
+
"max_position_embeddings": 32768,
|
| 18 |
+
"model_type": "mistraldual",
|
| 19 |
+
"num_attention_heads": 32,
|
| 20 |
+
"num_hidden_layers": 32,
|
| 21 |
+
"num_key_value_heads": 8,
|
| 22 |
+
"pad_token_id": 2,
|
| 23 |
+
"rms_norm_eps": 1e-05,
|
| 24 |
+
"rope_theta": 1000000.0,
|
| 25 |
+
"sliding_window": null,
|
| 26 |
+
"tie_word_embeddings": false,
|
| 27 |
+
"torch_dtype": "bfloat16",
|
| 28 |
+
"transformers_version": "4.51.3",
|
| 29 |
+
"use_cache": false,
|
| 30 |
+
"vocab_size": 32000
|
| 31 |
+
}
|
config_sentence_transformers.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"__version__": {
|
| 3 |
+
"sentence_transformers": "3.3.1",
|
| 4 |
+
"transformers": "4.51.3",
|
| 5 |
+
"pytorch": "2.5.1+cu124"
|
| 6 |
+
},
|
| 7 |
+
"prompts": {},
|
| 8 |
+
"default_prompt_name": null,
|
| 9 |
+
"similarity_fn_name": "cosine"
|
| 10 |
+
}
|
configuration_mistral_dual.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import MistralConfig, AutoConfig
|
| 2 |
+
|
| 3 |
+
class MistralDualConfig(MistralConfig):
|
| 4 |
+
model_type = "mistraldual"
|
| 5 |
+
def __init__(
|
| 6 |
+
self,
|
| 7 |
+
use_cache=False,
|
| 8 |
+
**kwargs,
|
| 9 |
+
):
|
| 10 |
+
super().__init__(use_cache=use_cache, **kwargs)
|
| 11 |
+
|
| 12 |
+
AutoConfig.register("mistraldual", MistralDualConfig)
|
| 13 |
+
MistralDualConfig.register_for_auto_class()
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:123545ec5709febf110e9a202a2d6ba4b8fe938228035d76240fc917a99db184
|
| 3 |
+
size 4943161760
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47d6b5a35b11e98b0c5f696ccdfbed64e173c173064c3407d8eb38e5ba41405b
|
| 3 |
+
size 4999818704
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:022856a3bf73e8027eb415bc64ee0efa75c124eee5f888fae8a03e01b9789482
|
| 3 |
+
size 4278371712
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 14221320192
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 7 |
+
"layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 8 |
+
"layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 9 |
+
"layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 10 |
+
"layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 11 |
+
"layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 12 |
+
"layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 13 |
+
"layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 14 |
+
"layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 15 |
+
"layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 16 |
+
"layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 17 |
+
"layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 18 |
+
"layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 19 |
+
"layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 20 |
+
"layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 21 |
+
"layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 22 |
+
"layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 23 |
+
"layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 24 |
+
"layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 25 |
+
"layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 26 |
+
"layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 27 |
+
"layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 28 |
+
"layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 29 |
+
"layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 30 |
+
"layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 31 |
+
"layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 32 |
+
"layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 33 |
+
"layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 34 |
+
"layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 35 |
+
"layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 36 |
+
"layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 37 |
+
"layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 38 |
+
"layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 39 |
+
"layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 40 |
+
"layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 41 |
+
"layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 42 |
+
"layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 43 |
+
"layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 44 |
+
"layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 45 |
+
"layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 46 |
+
"layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 47 |
+
"layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 48 |
+
"layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 49 |
+
"layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 50 |
+
"layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 51 |
+
"layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 52 |
+
"layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 53 |
+
"layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 54 |
+
"layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 55 |
+
"layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 56 |
+
"layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 57 |
+
"layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 58 |
+
"layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 59 |
+
"layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 60 |
+
"layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 61 |
+
"layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 62 |
+
"layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 63 |
+
"layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 64 |
+
"layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 65 |
+
"layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 66 |
+
"layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 67 |
+
"layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 68 |
+
"layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 69 |
+
"layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 70 |
+
"layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 71 |
+
"layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 72 |
+
"layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 73 |
+
"layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 74 |
+
"layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 75 |
+
"layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 76 |
+
"layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 77 |
+
"layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 78 |
+
"layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 79 |
+
"layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 80 |
+
"layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 81 |
+
"layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 82 |
+
"layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 83 |
+
"layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 84 |
+
"layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 85 |
+
"layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 86 |
+
"layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 87 |
+
"layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 88 |
+
"layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 89 |
+
"layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 90 |
+
"layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 91 |
+
"layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 92 |
+
"layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 93 |
+
"layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 94 |
+
"layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 95 |
+
"layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 96 |
+
"layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 97 |
+
"layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 98 |
+
"layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 99 |
+
"layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 100 |
+
"layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 101 |
+
"layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 102 |
+
"layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 103 |
+
"layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 104 |
+
"layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 105 |
+
"layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 106 |
+
"layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 107 |
+
"layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 108 |
+
"layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 109 |
+
"layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 110 |
+
"layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 111 |
+
"layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 112 |
+
"layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 113 |
+
"layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 114 |
+
"layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 115 |
+
"layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 116 |
+
"layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 117 |
+
"layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 118 |
+
"layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 119 |
+
"layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 120 |
+
"layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 121 |
+
"layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 122 |
+
"layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 123 |
+
"layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 124 |
+
"layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 125 |
+
"layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 126 |
+
"layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 127 |
+
"layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 128 |
+
"layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 129 |
+
"layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 130 |
+
"layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 131 |
+
"layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 132 |
+
"layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 133 |
+
"layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 134 |
+
"layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 135 |
+
"layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 136 |
+
"layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 137 |
+
"layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 138 |
+
"layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 139 |
+
"layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 140 |
+
"layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 141 |
+
"layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 142 |
+
"layers.22.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 143 |
+
"layers.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 144 |
+
"layers.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 145 |
+
"layers.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 146 |
+
"layers.22.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 147 |
+
"layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 148 |
+
"layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 149 |
+
"layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 150 |
+
"layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 151 |
+
"layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 152 |
+
"layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 153 |
+
"layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 154 |
+
"layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 155 |
+
"layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 156 |
+
"layers.23.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 157 |
+
"layers.23.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 158 |
+
"layers.23.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 159 |
+
"layers.23.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 160 |
+
"layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 161 |
+
"layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 162 |
+
"layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 163 |
+
"layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 164 |
+
"layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 165 |
+
"layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 166 |
+
"layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 167 |
+
"layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 168 |
+
"layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 169 |
+
"layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 170 |
+
"layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 171 |
+
"layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 172 |
+
"layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 173 |
+
"layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 174 |
+
"layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 175 |
+
"layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 176 |
+
"layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 177 |
+
"layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 178 |
+
"layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 179 |
+
"layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 180 |
+
"layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 181 |
+
"layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 182 |
+
"layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 183 |
+
"layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 184 |
+
"layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 185 |
+
"layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 186 |
+
"layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 187 |
+
"layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 188 |
+
"layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 189 |
+
"layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 190 |
+
"layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 191 |
+
"layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 192 |
+
"layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 193 |
+
"layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 194 |
+
"layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 195 |
+
"layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 199 |
+
"layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 201 |
+
"layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 203 |
+
"layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 205 |
+
"layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 206 |
+
"layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 207 |
+
"layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 211 |
+
"layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 213 |
+
"layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 214 |
+
"layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 215 |
+
"layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 216 |
+
"layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 217 |
+
"layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 218 |
+
"layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 219 |
+
"layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 220 |
+
"layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 221 |
+
"layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 222 |
+
"layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 223 |
+
"layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 224 |
+
"layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 225 |
+
"layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 226 |
+
"layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 227 |
+
"layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 228 |
+
"layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 229 |
+
"layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 230 |
+
"layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 231 |
+
"layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 232 |
+
"layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 233 |
+
"layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 234 |
+
"layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 235 |
+
"layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 236 |
+
"layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 237 |
+
"layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 238 |
+
"layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 239 |
+
"layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 240 |
+
"layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 241 |
+
"layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 242 |
+
"layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 243 |
+
"layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 244 |
+
"layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 245 |
+
"layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 246 |
+
"layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 247 |
+
"layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 248 |
+
"layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 249 |
+
"layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 250 |
+
"layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 251 |
+
"layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 252 |
+
"layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 253 |
+
"layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 254 |
+
"layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 255 |
+
"layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 256 |
+
"layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 257 |
+
"layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 258 |
+
"layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 259 |
+
"layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 260 |
+
"layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 261 |
+
"layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 262 |
+
"layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 263 |
+
"layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 264 |
+
"layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 265 |
+
"layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 266 |
+
"layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 267 |
+
"layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 268 |
+
"layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 269 |
+
"layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 270 |
+
"layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 271 |
+
"layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 272 |
+
"layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 273 |
+
"layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 274 |
+
"layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 275 |
+
"layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 276 |
+
"layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 277 |
+
"layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 278 |
+
"layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 279 |
+
"layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 280 |
+
"layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 281 |
+
"layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 282 |
+
"layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 283 |
+
"layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 284 |
+
"layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 285 |
+
"layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 286 |
+
"layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 287 |
+
"layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 288 |
+
"layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 289 |
+
"layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 290 |
+
"layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 291 |
+
"layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 292 |
+
"layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 293 |
+
"layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 294 |
+
"layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 295 |
+
"norm.weight": "model-00003-of-00003.safetensors"
|
| 296 |
+
}
|
| 297 |
+
}
|
modeling_mistral_dual.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 6 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 7 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 8 |
+
from transformers.processing_utils import Unpack
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
from transformers.models.mistral.configuration_mistral import MistralConfig
|
| 12 |
+
from transformers.models.mistral.modeling_mistral import MistralModel
|
| 13 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
| 14 |
+
|
| 15 |
+
from .configuration_mistral_dual import MistralDualConfig
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
class MistralDualModel(MistralModel):
|
| 20 |
+
config_class = MistralDualConfig
|
| 21 |
+
|
| 22 |
+
def __init__(self, config: MistralDualConfig):
|
| 23 |
+
super().__init__(config)
|
| 24 |
+
for layer in self.layers:
|
| 25 |
+
layer.self_attn.is_causal = False
|
| 26 |
+
|
| 27 |
+
def forward(
|
| 28 |
+
self,
|
| 29 |
+
input_ids: torch.LongTensor = None,
|
| 30 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 31 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 32 |
+
past_key_values: Optional[Cache] = None,
|
| 33 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 34 |
+
use_cache: Optional[bool] = None,
|
| 35 |
+
output_attentions: Optional[bool] = None,
|
| 36 |
+
output_hidden_states: Optional[bool] = None,
|
| 37 |
+
return_dict: Optional[bool] = None,
|
| 38 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 39 |
+
is_causal = False,
|
| 40 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 41 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 42 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 43 |
+
output_hidden_states = (
|
| 44 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 45 |
+
)
|
| 46 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 47 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 48 |
+
|
| 49 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 50 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 51 |
+
|
| 52 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 53 |
+
logger.warning_once(
|
| 54 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 55 |
+
)
|
| 56 |
+
use_cache = False
|
| 57 |
+
|
| 58 |
+
if inputs_embeds is None:
|
| 59 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 60 |
+
|
| 61 |
+
if use_cache and past_key_values is None:
|
| 62 |
+
past_key_values = DynamicCache()
|
| 63 |
+
|
| 64 |
+
if cache_position is None:
|
| 65 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 66 |
+
cache_position = torch.arange(
|
| 67 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
if position_ids is None:
|
| 71 |
+
position_ids = cache_position.unsqueeze(0)
|
| 72 |
+
|
| 73 |
+
causal_mask = self._update_causal_mask(
|
| 74 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# print(causal_mask)
|
| 78 |
+
|
| 79 |
+
hidden_states = inputs_embeds
|
| 80 |
+
|
| 81 |
+
# create position embeddings to be shared across the decoder layers
|
| 82 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 83 |
+
|
| 84 |
+
# decoder layers
|
| 85 |
+
all_hidden_states = () if output_hidden_states else None
|
| 86 |
+
all_self_attns = () if output_attentions else None
|
| 87 |
+
|
| 88 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 89 |
+
if output_hidden_states:
|
| 90 |
+
all_hidden_states += (hidden_states,)
|
| 91 |
+
|
| 92 |
+
if self.gradient_checkpointing and self.training:
|
| 93 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 94 |
+
partial(decoder_layer.__call__, is_causal=is_causal),
|
| 95 |
+
hidden_states,
|
| 96 |
+
causal_mask,
|
| 97 |
+
position_ids,
|
| 98 |
+
past_key_values,
|
| 99 |
+
output_attentions,
|
| 100 |
+
use_cache,
|
| 101 |
+
cache_position,
|
| 102 |
+
position_embeddings,
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
layer_outputs = decoder_layer(
|
| 106 |
+
hidden_states,
|
| 107 |
+
attention_mask=causal_mask,
|
| 108 |
+
position_ids=position_ids,
|
| 109 |
+
past_key_value=past_key_values,
|
| 110 |
+
output_attentions=output_attentions,
|
| 111 |
+
use_cache=use_cache,
|
| 112 |
+
cache_position=cache_position,
|
| 113 |
+
position_embeddings=position_embeddings,
|
| 114 |
+
is_causal=is_causal,
|
| 115 |
+
**flash_attn_kwargs,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
hidden_states = layer_outputs[0]
|
| 119 |
+
|
| 120 |
+
if output_attentions:
|
| 121 |
+
all_self_attns += (layer_outputs[1],)
|
| 122 |
+
|
| 123 |
+
hidden_states = self.norm(hidden_states)
|
| 124 |
+
|
| 125 |
+
# add hidden states from the last decoder layer
|
| 126 |
+
if output_hidden_states:
|
| 127 |
+
all_hidden_states += (hidden_states,)
|
| 128 |
+
|
| 129 |
+
output = BaseModelOutputWithPast(
|
| 130 |
+
last_hidden_state=hidden_states,
|
| 131 |
+
past_key_values=past_key_values if use_cache else None,
|
| 132 |
+
hidden_states=all_hidden_states,
|
| 133 |
+
attentions=all_self_attns,
|
| 134 |
+
)
|
| 135 |
+
return output if return_dict else output.to_tuple()
|
| 136 |
+
|
| 137 |
+
@staticmethod
|
| 138 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 139 |
+
attention_mask: torch.Tensor,
|
| 140 |
+
sequence_length: int,
|
| 141 |
+
target_length: int,
|
| 142 |
+
dtype: torch.dtype,
|
| 143 |
+
device: torch.device,
|
| 144 |
+
cache_position: torch.Tensor,
|
| 145 |
+
batch_size: int,
|
| 146 |
+
config: MistralConfig,
|
| 147 |
+
past_key_values: Cache,
|
| 148 |
+
):
|
| 149 |
+
"""
|
| 150 |
+
Creates a bidirectional 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`,
|
| 151 |
+
where all tokens can attend to all others.
|
| 152 |
+
"""
|
| 153 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 154 |
+
return attention_mask # Already in correct shape
|
| 155 |
+
|
| 156 |
+
min_dtype = torch.finfo(dtype).min
|
| 157 |
+
# Create a full attention mask allowing all tokens to attend to all others
|
| 158 |
+
bidirectional_mask = torch.zeros((sequence_length, target_length), dtype=dtype, device=device)
|
| 159 |
+
bidirectional_mask = bidirectional_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 160 |
+
|
| 161 |
+
if attention_mask is not None:
|
| 162 |
+
bidirectional_mask = bidirectional_mask.clone() # Ensure contiguous memory for in-place edit
|
| 163 |
+
if attention_mask.shape[-1] > target_length:
|
| 164 |
+
attention_mask = attention_mask[:, :target_length]
|
| 165 |
+
mask_length = attention_mask.shape[-1]
|
| 166 |
+
padding_mask = bidirectional_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 167 |
+
padding_mask = padding_mask == 0
|
| 168 |
+
bidirectional_mask[:, :, :, :mask_length] = bidirectional_mask[:, :, :, :mask_length].masked_fill(
|
| 169 |
+
padding_mask, min_dtype
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
return bidirectional_mask
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
AutoModel.register(MistralDualConfig, MistralDualModel)
|
| 176 |
+
MistralDualModel.register_for_auto_class()
|
modules.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "sentence_transformers.models.Transformer"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"idx": 1,
|
| 10 |
+
"name": "1",
|
| 11 |
+
"path": "1_Pooling",
|
| 12 |
+
"type": "sentence_transformers.models.Pooling"
|
| 13 |
+
}
|
| 14 |
+
]
|
sentence_bert_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_seq_length": 1024,
|
| 3 |
+
"do_lower_case": false
|
| 4 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<unk>",
|
| 4 |
+
"<s>",
|
| 5 |
+
"</s>"
|
| 6 |
+
],
|
| 7 |
+
"bos_token": {
|
| 8 |
+
"content": "<s>",
|
| 9 |
+
"lstrip": false,
|
| 10 |
+
"normalized": false,
|
| 11 |
+
"rstrip": false,
|
| 12 |
+
"single_word": false
|
| 13 |
+
},
|
| 14 |
+
"eos_token": {
|
| 15 |
+
"content": "</s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"pad_token": {
|
| 22 |
+
"content": "</s>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false
|
| 27 |
+
},
|
| 28 |
+
"unk_token": {
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false
|
| 34 |
+
}
|
| 35 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f59f5b1915d4523382c5c42d0610fd06ad474a210c8747884d4598cf6d657331
|
| 3 |
+
size 3506438
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": true,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"additional_special_tokens": [
|
| 32 |
+
"<unk>",
|
| 33 |
+
"<s>",
|
| 34 |
+
"</s>"
|
| 35 |
+
],
|
| 36 |
+
"bos_token": "<s>",
|
| 37 |
+
"clean_up_tokenization_spaces": false,
|
| 38 |
+
"eos_token": "</s>",
|
| 39 |
+
"extra_special_tokens": {},
|
| 40 |
+
"legacy": true,
|
| 41 |
+
"max_length": 512,
|
| 42 |
+
"model_max_length": 1024,
|
| 43 |
+
"pad_to_multiple_of": null,
|
| 44 |
+
"pad_token": "</s>",
|
| 45 |
+
"pad_token_type_id": 0,
|
| 46 |
+
"padding_side": "right",
|
| 47 |
+
"sp_model_kwargs": {},
|
| 48 |
+
"spaces_between_special_tokens": false,
|
| 49 |
+
"stride": 0,
|
| 50 |
+
"tokenizer_class": "LlamaTokenizerFast",
|
| 51 |
+
"truncation_side": "right",
|
| 52 |
+
"truncation_strategy": "longest_first",
|
| 53 |
+
"unk_token": "<unk>",
|
| 54 |
+
"use_default_system_prompt": false
|
| 55 |
+
}
|