File size: 17,211 Bytes
f567466
 
 
 
 
 
 
 
 
 
 
30aefbc
f567466
 
 
 
ad83ee7
f567466
 
 
 
 
2776189
9a77142
 
 
 
34887f1
9a77142
 
2776189
 
 
 
 
 
 
 
 
 
 
 
9a77142
2776189
0600503
2776189
9a77142
 
 
0600503
9a77142
2776189
 
ad83ee7
9a77142
34887f1
 
9a77142
 
 
34887f1
9a77142
 
34887f1
9a77142
2776189
 
 
 
 
 
 
 
 
 
 
 
 
9a77142
2776189
9a77142
 
 
 
34887f1
9a77142
2776189
 
f567466
2776189
f567466
 
2776189
 
9a77142
2776189
 
 
 
 
 
 
 
 
 
 
9a77142
 
2776189
 
 
 
 
 
 
 
34887f1
 
2776189
 
 
 
 
 
 
f567466
9a77142
 
 
 
f567466
 
 
 
 
 
 
 
 
9a77142
f567466
9a77142
f567466
 
 
 
 
 
9a77142
f567466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34887f1
 
f567466
 
 
 
 
 
 
 
34887f1
 
 
f567466
 
 
 
 
30aefbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad83ee7
 
f3eb784
30aefbc
f3eb784
 
ad83ee7
 
f3eb784
 
ad83ee7
f3eb784
ad83ee7
 
 
 
 
f3eb784
 
 
ad83ee7
 
f3eb784
ad83ee7
 
 
 
 
 
 
 
 
f567466
ad83ee7
 
 
 
f567466
ad83ee7
 
f567466
ad83ee7
f567466
30aefbc
 
ad83ee7
 
 
30aefbc
ad83ee7
 
30aefbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad83ee7
f567466
30aefbc
ad83ee7
 
30aefbc
ad83ee7
f567466
30aefbc
f567466
 
30aefbc
f567466
ad83ee7
 
 
f567466
2776189
 
 
 
 
f567466
 
2776189
 
f567466
2776189
f567466
2776189
 
f567466
2776189
 
 
 
 
 
 
 
 
26a5cc7
 
 
 
 
 
 
 
2776189
f567466
ad83ee7
 
 
ff7c83f
2776189
ad83ee7
 
 
ff7c83f
2776189
ad83ee7
 
 
ff7c83f
2776189
f567466
ad83ee7
2776189
ad83ee7
 
f567466
ad83ee7
f567466
 
ad83ee7
 
f567466
 
2776189
 
 
 
 
 
 
 
 
 
26a5cc7
 
 
 
2776189
 
 
 
 
26a5cc7
2776189
 
 
ad83ee7
 
 
30aefbc
f567466
2776189
ad83ee7
 
2776189
 
 
ad83ee7
 
 
30aefbc
2776189
87ba4cd
2776189
f567466
 
 
ad83ee7
f567466
ad83ee7
f567466
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import gradio as gr
import torch
import sys
import os
from pathlib import Path
import importlib.util
import huggingface_hub
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors, rdMolDescriptors
import numpy as np
from PIL import Image
import io

class SimpleMolecularApp:
    """Load ChemMiniQ3-SAbRLo checkpoints from the Hugging Face Hub and sample molecules.

    On construction the custom Python modules (config / model / tokenizer) and
    the tokenizer files are downloaded once from the main repo into
    ``./shared_modules``. Every checkpoint in ``SUPPORTED_MODELS`` is then
    loaded with those shared modules and that shared tokenizer.

    NOTE(review): the emoji in this copy's log/UI strings look mis-encoded
    (UTF-8 bytes rendered through a legacy code page) - confirm the file is
    saved and served as UTF-8.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.config = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Shared modules path
        self.SHARED_MODULES_DIR = Path("./shared_modules")
        self.SHARED_MODULES_DIR.mkdir(exist_ok=True)

        # Download shared modules and tokenizer files once
        self._ensure_shared_modules()

        # Supported models: display name -> Hub repo, optional subfolder inside
        # that repo, and the local directory the files are downloaded into.
        self.SUPPORTED_MODELS = {
            "Non-RL Pretrained": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": None,
                "local_dir": "./chemq3_non_rl_model"
            },
            "RL Finetuned – Step 9000": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": "ppo_checkpoints/model_step_9000",
                "local_dir": "./chemq3_rlnp_step9000"
            },
            "RL Pareto Finetuned – Step 2250": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_2250",
                "local_dir": "./chemq3_rlp_step2250"
            },
            "RL Pareto Finetuned – Step 4500": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_4500",
                "local_dir": "./chemq3_rlp_step4500"
            }
        }

    def _ensure_shared_modules(self):
        """Download shared Python modules and tokenizer files from main repo"""
        print("πŸ“¦ Downloading shared modules and tokenizer files from main repo...")
        huggingface_hub.snapshot_download(
            repo_id="gbyuvd/ChemMiniQ3-SAbRLo",
            local_dir=str(self.SHARED_MODULES_DIR),
            allow_patterns=["*.py", "tokenizer*", "vocab*", "merges*", "special_tokens*", "tokenizer_config*"],
            resume_download=True
        )
        print("βœ… Shared modules and tokenizer files ready!")

    def _download_checkpoint(self, repo_id, subfolder, local_dir):
        """Download one checkpoint into *local_dir* and return the directory holding its config."""
        if subfolder:
            # Only fetch the config/weight files of the requested checkpoint.
            huggingface_hub.snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                allow_patterns=[
                    f"{subfolder}/config.json",
                    f"{subfolder}/pytorch_model.bin",
                    f"{subfolder}/model.safetensors",
                    f"{subfolder}/generation_config.json",
                ],
                resume_download=True,
            )
            return Path(local_dir) / subfolder

        # Non-RL: no subfolder, so the whole repo is downloaded.
        # NOTE(review): without allow_patterns this also fetches checkpoints
        # stored in subfolders of the repo (e.g. ppo_checkpoints/ used by the
        # "RL Finetuned" entry above).
        huggingface_hub.snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            resume_download=True,
        )
        return Path(local_dir)

    def load_model_by_name(self, model_key):
        """Download (if needed) and activate the checkpoint registered as *model_key*.

        Args:
            model_key: A key of ``SUPPORTED_MODELS``.

        Returns:
            True when the model and tokenizer were loaded and are now active,
            False otherwise. On failure the previously active model, tokenizer
            and config are left in place so generation keeps working.
        """
        if model_key not in self.SUPPORTED_MODELS:
            print(f"❌ Unknown model: {model_key}")
            return False

        model_spec = self.SUPPORTED_MODELS[model_key]
        repo_id = model_spec["repo_id"]
        subfolder = model_spec["subfolder"]
        local_dir = model_spec["local_dir"]

        print(f"πŸ”„ Loading model: {model_key} from {repo_id}")

        # Network/Hub errors are reported through the boolean result, like
        # every other failure of this method.
        try:
            model_path = self._download_checkpoint(repo_id, subfolder, local_dir)
        except Exception as e:
            print(f"❌ Download failed for {model_key}: {e}")
            return False

        if not model_path.exists():
            print(f"❌ Model path not found: {model_path}")
            return False

        # Load custom modules from shared path
        loaded_modules = self.load_custom_modules_from_shared()
        if not loaded_modules:
            return False

        # Register model components
        config_class, _model_class, tokenizer_class = self.register_model_components(loaded_modules)
        if not config_class:
            return False

        # Load into locals first: a failed load must not discard the model
        # that is currently active.
        model, tokenizer, model_config = self.load_model_with_shared_tokenizer(
            model_path, tokenizer_class=tokenizer_class
        )
        if model is None:
            return False

        self.model, self.tokenizer, self.config = model, tokenizer, model_config
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"βœ… Successfully loaded: {model_key}")
        return True

    def load_custom_modules_from_shared(self):
        """Import the custom config/model/tokenizer modules from the shared directory.

        Returns:
            A dict mapping module name to the executed module object, or None
            if a required file is missing or fails to execute.
        """
        shared_dir = str(self.SHARED_MODULES_DIR)
        if shared_dir not in sys.path:
            # Lets imports by plain module name inside these files resolve to
            # the shared directory.
            sys.path.insert(0, shared_dir)

        required_files = {
            'configuration_chemq3mtp.py': 'configuration_chemq3mtp',
            'modeling_chemq3mtp.py': 'modeling_chemq3mtp',
            'FastChemTokenizerHF.py': 'FastChemTokenizerHF'
        }

        loaded_modules = {}
        for filename, module_name in required_files.items():
            file_path = self.SHARED_MODULES_DIR / filename
            if not file_path.exists():
                print(f"❌ Required file not found in shared modules: {filename}")
                return None
            try:
                spec = importlib.util.spec_from_file_location(module_name, file_path)
                module = importlib.util.module_from_spec(spec)
                # Register before executing (importlib's documented recipe) so a
                # sibling module importing this one by name receives this same
                # module object instead of a second copy with distinct classes.
                sys.modules[module_name] = module
                spec.loader.exec_module(module)
                loaded_modules[module_name] = module
                print(f"   βœ… Loaded {filename} from shared modules")
            except Exception as e:
                # Do not leave a half-initialised module importable.
                sys.modules.pop(module_name, None)
                print(f"   ❌ Failed to load {filename}: {e}")
                return None
        return loaded_modules

    def register_model_components(self, loaded_modules):
        """Register the custom config/model/tokenizer classes with the transformers Auto* factories.

        Args:
            loaded_modules: The dict returned by ``load_custom_modules_from_shared``.

        Returns:
            ``(config_class, model_class, tokenizer_class)`` on success, or
            ``(None, None, None)`` if any lookup or registration fails.
        """
        try:
            ChemQ3MTPConfig = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig
            ChemQ3MTPForCausalLM = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM
            FastChemTokenizerSelfies = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies

            AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
            AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
            AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)

            print("βœ… Model components registered successfully")
            return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies
        except Exception as e:
            print(f"❌ Registration failed: {e}")
            return None, None, None

    def load_model_with_shared_tokenizer(self, model_path, tokenizer_class=None):
        """Load config + weights from *model_path* and the tokenizer from the shared directory.

        Args:
            model_path: Directory containing the checkpoint's ``config.json`` and weights.
            tokenizer_class: The custom tokenizer class to instantiate. When
                omitted, the shared modules are imported to obtain it.

        Returns:
            ``(model, tokenizer, config)``, or ``(None, None, None)`` on any failure.
        """
        try:
            config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)
            model = AutoModelForCausalLM.from_pretrained(
                str(model_path),
                config=config,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=False
            )
            if tokenizer_class is None:
                loaded_modules = self.load_custom_modules_from_shared()
                if not loaded_modules:
                    print("❌ Model loading failed: custom tokenizer module unavailable")
                    return None, None, None
                tokenizer_class = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies
            # The tokenizer files live in the shared directory, not in the checkpoint.
            tokenizer = tokenizer_class.from_pretrained(str(self.SHARED_MODULES_DIR))
            return model, tokenizer, config
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return None, None, None

    def calculate_lipinski_properties(self, mol):
        """Calculate Lipinski's Rule of Five properties plus a few extra descriptors.

        Args:
            mol: An RDKit molecule, or None.

        Returns:
            A dict of rounded descriptor values and ``lipinski_violations``
            (count of MW > 500, HBD > 5, HBA > 10, LogP > 5); ``{}`` for None.
        """
        if mol is None:
            return {}

        # Calculate molecular descriptors
        molecular_weight = Descriptors.MolWt(mol)
        h_bond_donors = rdMolDescriptors.CalcNumHBD(mol)  # Hydrogen bond donors
        h_bond_acceptors = rdMolDescriptors.CalcNumHBA(mol)  # Hydrogen bond acceptors
        logp = Descriptors.MolLogP(mol)  # LogP (octanol-water partition coefficient)
        tpsa = Descriptors.TPSA(mol)  # Topological Polar Surface Area
        rotatable_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol)
        heavy_atoms = mol.GetNumHeavyAtoms()

        # Lipinski's Rule of Five violations
        violations = 0
        if molecular_weight > 500: violations += 1
        if h_bond_donors > 5: violations += 1
        if h_bond_acceptors > 10: violations += 1
        if logp > 5: violations += 1

        return {
            'molecular_weight': round(molecular_weight, 2),
            'h_bond_donors': h_bond_donors,
            'h_bond_acceptors': h_bond_acceptors,
            'logp': round(logp, 2),
            'tpsa': round(tpsa, 2),
            'rotatable_bonds': rotatable_bonds,
            'heavy_atoms': heavy_atoms,
            'lipinski_violations': violations
        }

    def generate_molecule(self, temperature=1.0, max_length=30, top_k=50):
        """Sample one molecule from the active model and describe it.

        Args:
            temperature: Sampling temperature.
            max_length: Maximum number of new tokens to generate.
            top_k: Top-k sampling cutoff.

        Returns:
            ``(info_text, mol_image, property_text)``. ``mol_image`` is the
            RDKit drawing of the molecule, or None when no valid molecule was
            decoded. Errors are reported in the text fields instead of raised.
        """
        if self.model is None:
            return "Model not loaded!", None, "❌ Model not loaded"

        try:
            # UI sliders may deliver whole numbers as floats; token counts and
            # top-k cutoffs must be ints for generation.
            max_length = int(max_length)
            top_k = int(top_k)

            # The prompt is just the "<s>" token.
            input_ids = self.tokenizer("<s>", return_tensors="pt").input_ids.to(self.device)

            if hasattr(self.model, 'generate_with_logprobs'):
                print("Using MTP-specific generation...")
                outputs = self.model.generate_with_logprobs(
                    input_ids,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_k=top_k,
                    do_sample=True,
                    return_probs=True,
                    tokenizer=self.tokenizer
                )
                # Extract tokens from MTP output (index 2)
                gen_tokens = outputs[2]
            else:
                print("Using standard generation...")
                gen_tokens = self.model.generate(
                    input_ids,
                    max_length=input_ids.shape[1] + max_length,
                    temperature=temperature,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else 0
                )

            # Decode the generated molecule: tokens -> SELFIES -> SMILES
            generatedmol = self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
            selfies_str = generatedmol.replace(' ', '')
            smiles = sf.decoder(selfies_str)

            info_text = f"Generated SELFIES: {selfies_str}\n"
            info_text += f"Decoded SMILES: {smiles}\n"

            # Visualize molecule
            mol_image = None
            property_text = ""

            if smiles:
                mol = Chem.MolFromSmiles(smiles)
                if mol:
                    # Generate molecule image
                    img = Draw.MolToImage(mol, size=(400, 400))
                    mol_image = img

                    # Calculate Lipinski properties
                    props = self.calculate_lipinski_properties(mol)

                    property_text = "πŸ§ͺ Molecular Properties (Lipinski's Rule of Five):\n"
                    property_text += f"β€’ Molecular Weight: {props['molecular_weight']} g/mol\n"
                    property_text += f"β€’ H-Bond Donors: {props['h_bond_donors']}\n"
                    property_text += f"β€’ H-Bond Acceptors: {props['h_bond_acceptors']}\n"
                    property_text += f"β€’ LogP: {props['logp']}\n"
                    property_text += f"β€’ TPSA: {props['tpsa']} Γ…Β²\n"
                    property_text += f"β€’ Rotatable Bonds: {props['rotatable_bonds']}\n"
                    property_text += f"β€’ Heavy Atoms: {props['heavy_atoms']}\n"
                    property_text += f"β€’ Lipinski Violations: {props['lipinski_violations']}/4\n"

                    # Rule of Five assessment
                    if props['lipinski_violations'] <= 1:
                        property_text += "βœ… Drug-like molecule (Lipinski compliant)"
                    else:
                        property_text += f"⚠️  May have poor bioavailability ({props['lipinski_violations']} violations)"

                    info_text += "βœ… Valid molecule generated!"
                else:
                    property_text = "⚠️ Could not calculate properties - invalid SMILES structure"
                    info_text += "⚠️ Invalid SMILES structure"
            else:
                property_text = "⚠️ Could not calculate properties - could not decode to SMILES"
                info_text += "⚠️ Could not decode to SMILES"

            return info_text, mol_image, property_text

        except Exception as e:
            return f"Error generating molecule: {str(e)}", None, "❌ Error calculating properties"

def create_simple_interface():
    """Build the Gradio Blocks UI around a ``SimpleMolecularApp``.

    The default checkpoint ("Non-RL Pretrained") is loaded before the UI is
    built, so the page opens ready to generate.

    Returns:
        The ``gr.Blocks`` demo, or None if the default model failed to load.
    """
    app = SimpleMolecularApp()

    # Preload default model (Non-RL)
    default_model = "Non-RL Pretrained"
    print(f"Initializing default model: {default_model}")
    if not app.load_model_by_name(default_model):
        print("Failed to initialize default model!")
        return None
    print("Model initialized successfully!")

    with gr.Blocks(title="πŸ§ͺ ChemMiniQ3-SAbRLo Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # πŸ§ͺ ChemMiniQ3-SAbRLo Demo
        
        Generate molecules using either the **Non-RL pretrained model** or **RL-finetuned checkpoints** 
        optimized with a **ParetoRewards controller**.
        """)

        with gr.Row():
            model_choice = gr.Dropdown(
                choices=list(app.SUPPORTED_MODELS.keys()),
                value=default_model,
                label="Select Model"
            )
            load_btn = gr.Button("πŸ” Load Selected Model", variant="secondary")

        # Model status indicator
        model_status = gr.Textbox(
            label="Model Status",
            value=f"βœ… Current Model: {default_model}",
            interactive=False,
            show_copy_button=True
        )

        # Generation controls.
        # gr.Slider has no `precision` option (that belongs to gr.Number);
        # step=1 is what restricts the integer sliders to whole numbers.
        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=1.0,
                    label="Temperature", info="Higher = more random",
                    step=0.1
                )
                length_slider = gr.Slider(
                    minimum=10, maximum=50, value=30,
                    label="Max Length", info="Max tokens to generate",
                    step=1
                )
                topk_slider = gr.Slider(
                    minimum=10, maximum=100, value=50,
                    label="Top-K", info="Sampling diversity",
                    step=1
                )
                generate_btn = gr.Button("πŸ§ͺ Generate Molecule", variant="primary")

            with gr.Column():
                mol_info = gr.Textbox(
                    label="Molecule Information",
                    lines=5,
                    interactive=False
                )
                mol_image = gr.Image(
                    label="Generated Molecule",
                    type="pil"
                )

        # Molecular properties section
        property_info = gr.Textbox(
            label="Molecular Properties (Lipinski's Rule of Five)",
            lines=10,
            interactive=False
        )

        def load_model_wrapper(model_name):
            """Load *model_name* and return the status line to display."""
            success = app.load_model_by_name(model_name)
            if success:
                status = f"βœ… Current Model: {model_name} (Ready to use!)"
            else:
                status = f"❌ Failed to load: {model_name}"
            return status

        load_btn.click(
            fn=load_model_wrapper,
            inputs=model_choice,
            outputs=model_status
        )

        # Generate molecule
        generate_btn.click(
            fn=app.generate_molecule,
            inputs=[temp_slider, length_slider, topk_slider],
            outputs=[mol_info, mol_image, property_info]
        )

        gr.Examples(
            examples=[
                [1.0, 30, 50],
                [0.8, 25, 40],
                [1.5, 35, 60],
            ],
            inputs=[temp_slider, length_slider, topk_slider],
            fn=app.generate_molecule,
            outputs=[mol_info, mol_image, property_info],
            cache_examples=False  # Outputs are sampled and the model can be switched, so never cache
        )

    return demo

if __name__ == "__main__":
    # Building the interface also downloads and loads the default checkpoint;
    # a falsy result means that step failed.
    demo = create_simple_interface()
    if not demo:
        print("Failed to create interface!")
    else:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)