Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import sys | |
| import os | |
| from pathlib import Path | |
| import importlib.util | |
| import huggingface_hub | |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | |
| import selfies as sf | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from rdkit.Chem import Descriptors, rdMolDescriptors | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
class SimpleMolecularApp:
    """Gradio-facing wrapper around the ChemMiniQ3-SAbRLo molecule generator.

    On construction it downloads the shared custom-code and tokenizer files
    into ``./shared_modules``. Checkpoints listed in ``SUPPORTED_MODELS`` are
    downloaded and activated on demand by ``load_model_by_name``;
    ``generate_molecule`` samples a SELFIES string from the active model,
    decodes it to SMILES and reports RDKit descriptors.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.config = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Custom model/tokenizer code plus tokenizer files; shared by every
        # checkpoint, so they are downloaded once here.
        self.SHARED_MODULES_DIR = Path("./shared_modules")
        self.SHARED_MODULES_DIR.mkdir(exist_ok=True)
        self._ensure_shared_modules()
        # Display name -> where to fetch the checkpoint. ``subfolder`` is None
        # when the weights sit at the repository root.
        self.SUPPORTED_MODELS = {
            "Non-RL Pretrained": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": None,
                "local_dir": "./chemq3_non_rl_model",
            },
            "RL Finetuned – Step 9000": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": "ppo_checkpoints/model_step_9000",
                "local_dir": "./chemq3_rlnp_step9000",
            },
            "RL Pareto Finetuned – Step 2250": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_2250",
                "local_dir": "./chemq3_rlp_step2250",
            },
            "RL Pareto Finetuned – Step 4500": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_4500",
                "local_dir": "./chemq3_rlp_step4500",
            },
        }

    def _ensure_shared_modules(self):
        """Download shared Python modules and tokenizer files from the main repo.

        ``resume_download`` is intentionally not passed: it is deprecated in
        recent ``huggingface_hub`` releases, which resume interrupted
        downloads by default.
        """
        print("📦 Downloading shared modules and tokenizer files from main repo...")
        huggingface_hub.snapshot_download(
            repo_id="gbyuvd/ChemMiniQ3-SAbRLo",
            local_dir=str(self.SHARED_MODULES_DIR),
            allow_patterns=["*.py", "tokenizer*", "vocab*", "merges*", "special_tokens*", "tokenizer_config*"],
        )
        print("✅ Shared modules and tokenizer files ready!")

    def _download_checkpoint(self, spec):
        """Fetch the checkpoint described by one SUPPORTED_MODELS entry.

        Args:
            spec: Dict with ``repo_id``, ``subfolder`` and ``local_dir`` keys.

        Returns:
            The local directory holding the checkpoint, or None if the
            download raised (the error is printed).
        """
        repo_id = spec["repo_id"]
        subfolder = spec["subfolder"]
        local_dir = spec["local_dir"]
        if subfolder:
            # Only the files needed to build the model from this subfolder.
            download_kwargs = {
                "allow_patterns": [
                    f"{subfolder}/config.json",
                    f"{subfolder}/pytorch_model.bin",
                    f"{subfolder}/model.safetensors",
                    f"{subfolder}/generation_config.json",
                ]
            }
            model_path = Path(local_dir) / subfolder
        else:
            # Root-level checkpoint: skip every nested path (the same repo also
            # hosts ppo_checkpoints/...), since from_pretrained() on the root
            # directory only reads root-level files.
            download_kwargs = {"ignore_patterns": ["*/*"]}
            model_path = Path(local_dir)
        try:
            huggingface_hub.snapshot_download(repo_id=repo_id, local_dir=local_dir, **download_kwargs)
        except Exception as e:
            print(f"❌ Download failed for {repo_id}: {e}")
            return None
        return model_path

    def load_model_by_name(self, model_key):
        """Download (if needed) and activate the checkpoint named ``model_key``.

        Args:
            model_key: A key of ``SUPPORTED_MODELS``.

        Returns:
            True if the checkpoint is now the active model, False otherwise.
            On failure the previously active model, tokenizer and config are
            left untouched, so the app keeps serving the old model.
        """
        if model_key not in self.SUPPORTED_MODELS:
            print(f"❌ Unknown model: {model_key}")
            return False

        spec = self.SUPPORTED_MODELS[model_key]
        print(f"🔄 Loading model: {model_key} from {spec['repo_id']}")

        model_path = self._download_checkpoint(spec)
        if model_path is None:
            return False
        if not model_path.exists():
            print(f"❌ Model path not found: {model_path}")
            return False

        loaded_modules = self.load_custom_modules_from_shared()
        if not loaded_modules:
            return False

        config_class, _model_class, tokenizer_class = self.register_model_components(loaded_modules)
        if not config_class:
            return False

        # Build the new model in locals; only replace the active one on success.
        model, tokenizer, config = self.load_model_with_shared_tokenizer(model_path, tokenizer_class)
        if model is None:
            return False

        model = model.to(self.device)
        model.eval()
        self.model, self.tokenizer, self.config = model, tokenizer, config
        print(f"✅ Successfully loaded: {model_key}")
        return True

    def load_custom_modules_from_shared(self):
        """Import the custom config/model/tokenizer modules from the shared dir.

        Returns:
            Dict mapping module name -> module object, or None if a required
            file is missing or fails to execute (the reason is printed).
        """
        shared_dir = str(self.SHARED_MODULES_DIR)
        # Presumably so the custom modules can import one another by name.
        if shared_dir not in sys.path:
            sys.path.insert(0, shared_dir)

        required_files = {
            'configuration_chemq3mtp.py': 'configuration_chemq3mtp',
            'modeling_chemq3mtp.py': 'modeling_chemq3mtp',
            'FastChemTokenizerHF.py': 'FastChemTokenizerHF',
        }
        loaded_modules = {}
        for filename, module_name in required_files.items():
            file_path = self.SHARED_MODULES_DIR / filename
            if not file_path.exists():
                print(f"❌ Required file not found in shared modules: {filename}")
                return None
            try:
                spec = importlib.util.spec_from_file_location(module_name, file_path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except Exception as e:
                print(f"  ❌ Failed to load {filename}: {e}")
                return None
            loaded_modules[module_name] = module
            print(f"  ✅ Loaded {filename} from shared modules")
        return loaded_modules

    def register_model_components(self, loaded_modules):
        """Register the custom classes with the transformers Auto* factories.

        Returns:
            ``(config_class, model_class, tokenizer_class)``, or
            ``(None, None, None)`` if lookup or registration raised.
        """
        try:
            ChemQ3MTPConfig = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig
            ChemQ3MTPForCausalLM = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM
            FastChemTokenizerSelfies = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies
            AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
            AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
            AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)
            print("✅ Model components registered successfully")
            return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies
        except Exception as e:
            print(f"❌ Registration failed: {e}")
            return None, None, None

    def load_model_with_shared_tokenizer(self, model_path, tokenizer_class=None):
        """Load config + weights from ``model_path`` and the shared tokenizer.

        Args:
            model_path: Directory containing the checkpoint's config/weights.
            tokenizer_class: Tokenizer class to instantiate from the shared
                dir. When omitted, the shared modules are (re)imported to find
                ``FastChemTokenizerSelfies``.

        Returns:
            ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure.
        """
        try:
            config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)
            model = AutoModelForCausalLM.from_pretrained(
                str(model_path),
                config=config,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=False,
            )
            if tokenizer_class is None:
                modules = self.load_custom_modules_from_shared()
                if not modules:
                    print("❌ Model loading failed: shared tokenizer module unavailable")
                    return None, None, None
                tokenizer_class = modules['FastChemTokenizerHF'].FastChemTokenizerSelfies
            # Tokenizer files are shared by all checkpoints (see __init__).
            tokenizer = tokenizer_class.from_pretrained(str(self.SHARED_MODULES_DIR))
            return model, tokenizer, config
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return None, None, None

    def calculate_lipinski_properties(self, mol):
        """Compute Rule-of-Five related RDKit descriptors for ``mol``.

        Returns:
            Dict of rounded descriptors plus ``lipinski_violations`` (count of
            MW > 500, HBD > 5, HBA > 10, LogP > 5), or ``{}`` if ``mol`` is None.
        """
        if mol is None:
            return {}
        molecular_weight = Descriptors.MolWt(mol)
        h_bond_donors = rdMolDescriptors.CalcNumHBD(mol)
        h_bond_acceptors = rdMolDescriptors.CalcNumHBA(mol)
        logp = Descriptors.MolLogP(mol)  # octanol-water partition coefficient
        tpsa = Descriptors.TPSA(mol)  # topological polar surface area
        rotatable_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol)
        heavy_atoms = mol.GetNumHeavyAtoms()

        violations = sum((
            molecular_weight > 500,
            h_bond_donors > 5,
            h_bond_acceptors > 10,
            logp > 5,
        ))
        return {
            'molecular_weight': round(molecular_weight, 2),
            'h_bond_donors': h_bond_donors,
            'h_bond_acceptors': h_bond_acceptors,
            'logp': round(logp, 2),
            'tpsa': round(tpsa, 2),
            'rotatable_bonds': rotatable_bonds,
            'heavy_atoms': heavy_atoms,
            'lipinski_violations': violations,
        }

    @staticmethod
    def _format_property_report(props):
        """Render a ``calculate_lipinski_properties`` dict as UI text."""
        violations = props['lipinski_violations']
        lines = [
            "🧪 Molecular Properties (Lipinski's Rule of Five):",
            f"• Molecular Weight: {props['molecular_weight']} g/mol",
            f"• H-Bond Donors: {props['h_bond_donors']}",
            f"• H-Bond Acceptors: {props['h_bond_acceptors']}",
            f"• LogP: {props['logp']}",
            f"• TPSA: {props['tpsa']} Å²",
            f"• Rotatable Bonds: {props['rotatable_bonds']}",
            f"• Heavy Atoms: {props['heavy_atoms']}",
            f"• Lipinski Violations: {violations}/4",
        ]
        # At most one violation is treated as drug-like.
        if violations <= 1:
            lines.append("✅ Drug-like molecule (Lipinski compliant)")
        else:
            lines.append(f"⚠️ May have poor bioavailability ({violations} violations)")
        return "\n".join(lines)

    def generate_molecule(self, temperature=1.0, max_length=30, top_k=50):
        """Sample one molecule from the active model and describe it.

        Args:
            temperature: Sampling temperature (higher = more random).
            max_length: Maximum number of new tokens to generate.
            top_k: Number of highest-probability tokens kept when sampling.

        Returns:
            ``(info_text, image, property_text)``: a generation summary, a
            400x400 RDKit drawing (None when no valid molecule was produced),
            and the property report. Errors are reported in the text fields.
        """
        if self.model is None:
            return "Model not loaded!", None, "❌ Model not loaded"
        try:
            # UI sliders may deliver floats; token counts and top-k must be ints.
            temperature = float(temperature)
            max_length = int(max_length)
            top_k = int(top_k)

            # Generation starts from the bare "<s>" prompt.
            input_ids = self.tokenizer("<s>", return_tensors="pt").input_ids.to(self.device)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                if hasattr(self.model, 'generate_with_logprobs'):
                    print("Using MTP-specific generation...")
                    outputs = self.model.generate_with_logprobs(
                        input_ids,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True,
                        return_probs=True,
                        tokenizer=self.tokenizer,
                    )
                    # Tokens are at index 2 of the MTP output.
                    gen_tokens = outputs[2]
                else:
                    print("Using standard generation...")
                    gen_tokens = self.model.generate(
                        input_ids,
                        max_length=input_ids.shape[1] + max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True,
                        pad_token_id=self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else 0,
                    )

            # Drop spaces from the decoded text so it forms one SELFIES string.
            decoded = self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
            selfies_str = decoded.replace(' ', '')
            smiles = sf.decoder(selfies_str)
            info_text = f"Generated SELFIES: {selfies_str}\nDecoded SMILES: {smiles}\n"

            if not smiles:
                return (
                    info_text + "⚠️ Could not decode to SMILES",
                    None,
                    "⚠️ Could not calculate properties - could not decode to SMILES",
                )
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return (
                    info_text + "⚠️ Invalid SMILES structure",
                    None,
                    "⚠️ Could not calculate properties - invalid SMILES structure",
                )

            mol_image = Draw.MolToImage(mol, size=(400, 400))
            property_text = self._format_property_report(self.calculate_lipinski_properties(mol))
            return info_text + "✅ Valid molecule generated!", mol_image, property_text
        except Exception as e:
            return f"Error generating molecule: {str(e)}", None, "❌ Error calculating properties"
def create_simple_interface():
    """Build the Gradio Blocks UI around a SimpleMolecularApp.

    The "Non-RL Pretrained" checkpoint is loaded up front so the demo is
    usable immediately; other checkpoints can be switched in from the UI.

    Returns:
        The ``gr.Blocks`` demo, or None if the default model failed to load.
    """
    app = SimpleMolecularApp()

    default_model = "Non-RL Pretrained"
    print(f"Initializing default model: {default_model}")
    if not app.load_model_by_name(default_model):
        print("Failed to initialize default model!")
        return None
    print("Model initialized successfully!")

    with gr.Blocks(title="🧪 ChemMiniQ3-SAbRLo Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# 🧪 ChemMiniQ3-SAbRLo Demo\n"
            "Generate molecules using either the **Non-RL pretrained model** or "
            "**RL-finetuned checkpoints** optimized with a **ParetoRewards controller**."
        )

        # Model selection
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=list(app.SUPPORTED_MODELS.keys()),
                value=default_model,
                label="Select Model",
            )
            load_btn = gr.Button("🔄 Load Selected Model", variant="secondary")

        model_status = gr.Textbox(
            label="Model Status",
            value=f"✅ Current Model: {default_model}",
            interactive=False,
            show_copy_button=True,
        )

        # Generation controls. gr.Slider has no `precision` argument; integer
        # values come from step=1 and are cast to int in generate_molecule.
        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=1.0,
                    label="Temperature", info="Higher = more random",
                    step=0.1,
                )
                length_slider = gr.Slider(
                    minimum=10, maximum=50, value=30,
                    label="Max Length", info="Max tokens to generate",
                    step=1,
                )
                topk_slider = gr.Slider(
                    minimum=10, maximum=100, value=50,
                    label="Top-K", info="Sampling diversity",
                    step=1,
                )
                generate_btn = gr.Button("🧪 Generate Molecule", variant="primary")
            with gr.Column():
                mol_info = gr.Textbox(
                    label="Molecule Information",
                    lines=5,
                    interactive=False,
                )
                mol_image = gr.Image(
                    label="Generated Molecule",
                    type="pil",
                )

        property_info = gr.Textbox(
            label="Molecular Properties (Lipinski's Rule of Five)",
            lines=10,
            interactive=False,
        )

        def load_model_wrapper(model_name):
            """Load the chosen checkpoint and return a status line for the UI."""
            if app.load_model_by_name(model_name):
                return f"✅ Current Model: {model_name} (Ready to use!)"
            return f"❌ Failed to load: {model_name}"

        load_btn.click(
            fn=load_model_wrapper,
            inputs=model_choice,
            outputs=model_status,
        )

        generate_btn.click(
            fn=app.generate_molecule,
            inputs=[temp_slider, length_slider, topk_slider],
            outputs=[mol_info, mol_image, property_info],
        )

        gr.Examples(
            examples=[
                [1.0, 30, 50],
                [0.8, 25, 40],
                [1.5, 35, 60],
            ],
            inputs=[temp_slider, length_slider, topk_slider],
            fn=app.generate_molecule,
            outputs=[mol_info, mol_image, property_info],
            cache_examples=False,  # the active model can change at runtime
        )

    return demo
if __name__ == "__main__":
    demo = create_simple_interface()
    if demo is None:
        # Exit non-zero so the process supervisor sees the startup failure.
        print("Failed to create interface!")
        sys.exit(1)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)