from typing import Any, Dict
from copy import deepcopy

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

from .WISE import WISE
from .utils import tokenize, get_context_templates
from .wise_hparams import WISEHyperParams


def apply_wise_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: WISEHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> WISE:
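    """Apply a single WISE edit to ``model`` and return the fitted editor.

    ``request`` is a dict with at least ``prompt`` and ``target_new`` keys.
    ``num_steps`` and ``edit_lr`` override ``hparams.n_iter`` and
    ``hparams.edit_lr`` before editing.
    """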
    if copy:
        model = deepcopy(model)
    weights_copy = {}  # currently unused; nothing is restored or returned from it
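    # Override the configured optimization settings with the caller-supplied values.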
    hparams.n_iter = num_steps
    hparams.edit_lr = edit_lr
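    # Prepare context templates (prompt prefixes) used to contextualize the edit,
    # then wrap the model in a WISE editor.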
    context_templates = get_context_templates(
        model, tok, length_params=[[5, 5], [10, 5]], device=hparams.device
    )
    editor = WISE(model=model, config=hparams, device=hparams.device)
    print(
        f"Executing WISE algorithm for the update: "
        f"[{request['prompt']}] -> [{request['target_new']}]"
    )
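    # Tokenize the request; act_mask / deact_mask flag the tokens used by
    # WISE's activation-based routing losses.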
    tokens, act_mask, deact_mask = tokenize(
        request, tokenizer=tok, device=hparams.device,
        context_templates=context_templates, hparams=hparams,
    )
    editor.edit(config=hparams, tokens=tokens, act_mask=act_mask, deact_mask=deact_mask)
    editor.to('cpu')  # move the edited modules off the GPU before returning
    gr.Info("Completed editing via WISE!")  # surface a notification in the Gradio UI
    return editor
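

# Usage sketch (comments only, so nothing runs on import). The model name, the
# hparams YAML path, and the request contents below are placeholders, and
# WISEHyperParams.from_hparams is assumed to be the loader used elsewhere in
# this repo; depending on how `tokenize` is implemented, the request may also
# need a 'loc_prompt' field.
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
#   tok = AutoTokenizer.from_pretrained("gpt2-xl")
#   hparams = WISEHyperParams.from_hparams("hparams/WISE/gpt2-xl.yaml")
#   request = {"prompt": "The capital of France is", "target_new": "Lyon"}
#   editor = apply_wise_to_model(model, tok, request, hparams,
#                                num_steps=40, edit_lr=1.0)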