---
license: mit
tags:
- chemistry
- molecule
- drug
---

# Roberta Zinc Decoder

This model is a GPT-2 style decoder model designed to reconstruct SMILES strings from embeddings created by the
[roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model. The decoder model was
trained on 30m compounds from the [ZINC Database](https://zinc.docking.org/).

The decoder model conditions generation on mean-pooled embeddings from the encoder model. Mean-pooled
embeddings are used to allow integration with vector databases, which require fixed-length embeddings.
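
Because each compound maps to a single fixed-length vector, the embeddings can be dropped directly into a
nearest-neighbor index. Below is a minimal sketch using [FAISS](https://github.com/facebookresearch/faiss)
(an illustrative choice, not a dependency of this repo); `mean_embeddings` is assumed to be computed as in
the usage example further down.

```python
import faiss
import numpy as np

# mean-pooled encoder embeddings as float32 numpy, one row per compound
embeddings = mean_embeddings.detach().numpy().astype(np.float32)

# cosine-similarity search: L2-normalize in place, then use an inner-product index
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# query with the first embedding; returns similarity scores and row ids of the 3 nearest neighbors
scores, ids = index.search(embeddings[:1], 3)
```
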
Conditioning embeddings are passed to the decoder model through the `encoder_hidden_states` argument.
The standard `GPT2LMHeadModel` does not support generation with encoder hidden states, so this repo
includes a custom `ConditionalGPT2LMHeadModel`. See the example below for how to instantiate the model.

```python
import torch
from transformers import AutoModelForCausalLM, RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", model_max_length=256)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

encoder_model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
encoder_model.eval();

commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
decoder_model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
                                                     trust_remote_code=True, revision=commit_hash)
decoder_model.eval();

smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
          'Brc1cc2c(NCc3ccccn3)ncnc2s1',
          'Brc1cc2c(NCc3cccs3)ncnc2s1',
          'Brc1cc2c(NCc3ccncc3)ncnc2s1',
          'Brc1cc2c(Nc3ccccc3)ncnc2s1']

# tokenize and pad the input SMILES
inputs = collator(tokenizer(smiles))

# run the encoder and grab the last hidden layer
with torch.no_grad():
    outputs = encoder_model(**inputs, output_hidden_states=True)
full_embeddings = outputs.hidden_states[-1]

# mean pool over non-padding tokens to get one fixed-length embedding per input
mask = inputs['attention_mask']
mean_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))

# generation starts from a single BOS token per input
decoder_inputs = torch.tensor([[tokenizer.bos_token_id] for i in range(len(smiles))])

# conditioning embeddings must have shape (batch_size, 1, hidden_dim)
hidden_states = mean_embeddings[:, None]

gen = decoder_model.generate(
    decoder_inputs,
    encoder_hidden_states=hidden_states,
    do_sample=False, # greedy decoding is recommended
    max_length=100,
    temperature=1.,
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
)

reconstructed_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
```

## Model Performance

The decoder model was evaluated on a test set of 1m compounds from ZINC. Compounds
were encoded with the [roberta_zinc_480m](https://huggingface.co/entropy/roberta_zinc_480m) model
and reconstructed with the decoder model.

The following metrics are computed (a sketch of how the two similarity metrics can be reproduced follows the results table):
* `exact_match` - percent of inputs exactly reconstructed
* `token_accuracy` - percent of output tokens exactly matching input tokens (excluding padding)
* `valid_structure` - percent of generated outputs that resolve to a valid SMILES string
* `tanimoto` - Tanimoto similarity between inputs and generated outputs, excluding invalid structures
* `cos_sim` - cosine similarity between input encoder embeddings and output encoder embeddings

`eval_type=full` reports metrics for the full 1m compound test set.

`eval_type=failed` reports the same metrics restricted to the generated outputs that failed to exactly reproduce their inputs.

|eval_type|exact_match|token_accuracy|valid_structure|tanimoto|cos_sim |
|---------|-----------|--------------|---------------|--------|--------|
|full     |0.948277   |0.990704      |0.994278       |0.987698|0.998224|
|failed   |0.000000   |0.820293      |0.889372       |0.734097|0.965668|
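
The exact evaluation pipeline is not specified here; the sketch below shows one plausible way to compute the
two similarity metrics, using RDKit Morgan fingerprints for `tanimoto` (an assumption; the radius and bit
settings are illustrative) and a plain cosine over the mean-pooled encoder embeddings for `cos_sim`.

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

def tanimoto(smiles_a, smiles_b):
    # Morgan fingerprint Tanimoto similarity; returns None if either SMILES is invalid
    mols = [Chem.MolFromSmiles(s) for s in (smiles_a, smiles_b)]
    if any(m is None for m in mols):
        return None
    fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=2048) for m in mols]
    return DataStructs.TanimotoSimilarity(fps[0], fps[1])

def cos_sim(a, b):
    # cosine similarity between two 1-d embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```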