---
license: mit
tags:
- chemistry
- molecule
- drug
---

# Roberta Zinc Decoder

This model is a GPT-2 decoder designed to reconstruct SMILES strings from embeddings created by the
[roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model. The decoder was
trained on 30 million compounds from the [ZINC Database](https://zinc.docking.org/).

The decoder conditions generation on mean-pooled embeddings from the encoder model. Mean-pooled
embeddings are used to allow integration with vector databases, which require fixed-length embeddings.
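
For example, the pooled vectors can be stored in a similarity index. The following is a minimal
sketch, not part of this repo, assuming the `faiss-cpu` package is installed and that
`mean_embeddings` has been computed as in the example below:

```python
import faiss  # pip install faiss-cpu
import numpy as np

# `mean_embeddings` is the (n_compounds, hidden_dim) tensor produced by
# mean pooling the encoder outputs (see the example below)
vectors = mean_embeddings.detach().numpy().astype(np.float32)

faiss.normalize_L2(vectors)                   # normalize so inner product == cosine similarity
index = faiss.IndexFlatIP(vectors.shape[-1])  # exact inner-product index
index.add(vectors)

# retrieve the 3 nearest neighbors of the first compound
similarities, neighbor_ids = index.search(vectors[:1], 3)
```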

Conditioning embeddings are passed to the decoder model through the `encoder_hidden_states` argument.
The standard `GPT2LMHeadModel` does not support generation with encoder hidden states, so this repo
includes a custom `ConditionalGPT2LMHeadModel`. See the example below for how to instantiate the model.

```python
import torch
from transformers import AutoModelForCausalLM, RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=256)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

encoder_model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
encoder_model.eval();

# pin the decoder to a specific commit of the custom model code
commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
decoder_model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
                                                     trust_remote_code=True, revision=commit_hash)
decoder_model.eval();

smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
          'Brc1cc2c(NCc3ccccn3)ncnc2s1',
          'Brc1cc2c(NCc3cccs3)ncnc2s1',
          'Brc1cc2c(NCc3ccncc3)ncnc2s1',
          'Brc1cc2c(Nc3ccccc3)ncnc2s1']

# tokenize with padding, then grab the final hidden states from the encoder
inputs = collator(tokenizer(smiles))
outputs = encoder_model(**inputs, output_hidden_states=True)
full_embeddings = outputs.hidden_states[-1]

# mean-pool the token embeddings, excluding padding positions via the attention mask
mask = inputs['attention_mask']
mean_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))

# generation starts from a single BOS token per input
decoder_inputs = torch.tensor([[tokenizer.bos_token_id] for i in range(len(smiles))])

hidden_states = mean_embeddings[:, None]  # hidden states shape (bs, 1, d)

gen = decoder_model.generate(
    decoder_inputs,
    encoder_hidden_states=hidden_states,
    do_sample=False,  # greedy decoding is recommended
    max_length=100,
    temperature=1.,
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
)

reconstructed_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
```
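
A quick sanity check is to compare the reconstructions against the original inputs:

```python
# count exact round-trip reconstructions
n_exact = sum(inp == out for inp, out in zip(smiles, reconstructed_smiles))
print(f'{n_exact}/{len(smiles)} exact matches')

for inp, out in zip(smiles, reconstructed_smiles):
    print(f"{'match  ' if inp == out else 'differs'}: {out}")
```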

## Model Performance

The decoder model was evaluated on a test set of 1 million compounds from ZINC. Compounds
were encoded with the [roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model
and reconstructed with the decoder model.

The following metrics are computed:
* `exact_match` - fraction of inputs exactly reconstructed
* `token_accuracy` - fraction of output tokens exactly matching input tokens (excluding padding)
* `valid_structure` - fraction of generated outputs that resolve to a valid SMILES string
* `tanimoto` - Tanimoto similarity between inputs and generated outputs, excluding invalid structures (see the sketch below)
* `cos_sim` - cosine similarity between input encoder embeddings and output encoder embeddings
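
The structure-based metrics can be reproduced with RDKit. The sketch below is illustrative: the card
does not specify which fingerprint was used for `tanimoto`, so Morgan fingerprints (radius 2,
2048 bits) are assumed here, and `input_embeddings` / `output_embeddings` are hypothetical names for
mean-pooled encoder embeddings of the inputs and outputs, computed as in the example above:

```python
import torch.nn.functional as F
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def tanimoto(smiles_in, smiles_out):
    """Tanimoto similarity of Morgan fingerprints; None if the output is invalid."""
    mol_in = Chem.MolFromSmiles(smiles_in)
    mol_out = Chem.MolFromSmiles(smiles_out)  # returns None for invalid structures
    if mol_out is None:
        return None
    fp_in = AllChem.GetMorganFingerprintAsBitVect(mol_in, 2, nBits=2048)
    fp_out = AllChem.GetMorganFingerprintAsBitVect(mol_out, 2, nBits=2048)
    return DataStructs.TanimotoSimilarity(fp_in, fp_out)

# `valid_structure` per pair: the generated output parses to a valid molecule
valid = [tanimoto(i, o) is not None for i, o in zip(smiles, reconstructed_smiles)]

# `cos_sim`: cosine similarity between input and output encoder embeddings
# (both mean-pooled from roberta_zinc_480m; hypothetical variable names)
# cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1)
```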

`eval_type=full` reports metrics over the full 1 million compound test set.

`eval_type=failed` reports metrics over only the subset of generated outputs that failed to exactly reconstruct the inputs.

| |eval_type|exact_match|token_accuracy|valid_structure|tanimoto|cos_sim | |
| |---------|-----------|--------------|---------------|--------|--------| |
| |full |0.948277 |0.990704 |0.994278 |0.987698|0.998224| |
| |failed |0.000000 |0.820293 |0.889372 |0.734097|0.965668| |