| from datasets import load_dataset | |
| from unfat.datasets import hub_prompts, HubSplit, Dataset, Prompts | |
| from unfat.extract import Extractor, ClientOpts | |
| from unfat.lora import LoraSettings | |
| import os | |
# Passed to the Extractor below and reused as the save location for the
# generated axolotl config.
output_dir = "output"

# Dataset name handed to `load_dataset`; also embedded in the prompt output
# path ("hub/<name>.jsonl") further down.
uncensor_ds_name = "Guilherme34/uncensor"

# Only the "train" split is loaded; `uncensor_items` iterates over it.
uncensor_ds = load_dataset(uncensor_ds_name, split="train")
def uncensor_items():
    """Yield the first user message of every row in ``uncensor_ds``.

    Each row's ``messages`` list is scanned in order. The ``content`` of the
    first entry whose ``role`` equals ``"user"`` is yielded, after which the
    remaining messages of that row are skipped. A row that contains no user
    message contributes nothing.
    """
    for record in uncensor_ds:
        for turn in record["messages"]:
            if turn["role"] != "user":
                continue
            yield turn["content"]
            # Only the first user turn of each row is used as a prompt.
            break
# Client settings for the glhf.chat endpoint. The API key is read from the
# environment, so a missing GLHF_API_KEY raises KeyError right here.
glhf_client_opts = ClientOpts(
    base_url="https://glhf.chat/api/openai/v1",
    api_key=os.environ["GLHF_API_KEY"],
)

# Training prompts: the locally filtered uncensor dataset (first user turn of
# each row) plus the "train" split of mlabonne/harmful_behaviors.
train_prompt_sets = [
    Prompts(
        output_path=f"hub/{uncensor_ds_name}.jsonl",
        # Evaluated lazily; reports the number of rows in the loaded split.
        count=lambda: len(uncensor_ds),
        items=uncensor_items,
    ),
    hub_prompts(
        name="mlabonne/harmful_behaviors",
        text_field="text",
        split=HubSplit(name="train"),
    ),
]

# Evaluation prompts: the "test" split of the same hub dataset.
eval_prompt_sets = [
    hub_prompts(
        name="mlabonne/harmful_behaviors",
        text_field="text",
        split=HubSplit(name="test"),
    ),
]

# Extractor wired to the teacher model, the client settings, and the prompt
# sets above; its results go under `output_dir`.
extractor = Extractor(
    teacher="hf:mlabonne/Llama-3.1-70B-Instruct-lorablated",
    max_concurrent=8,
    output_dir=output_dir,
    client_opts=glhf_client_opts,
    dataset=Dataset(
        train=train_prompt_sets,
        eval=eval_prompt_sets,
    ),
)
# Hyperparameters handed to unfat's LoraSettings; the axolotl config below is
# generated from this object.
lora_settings = LoraSettings(
    # Training schedule.
    num_epochs=2,
    learning_rate=4e-4,
    warmup_steps=10,
    # LoRA adapter parameters.
    lora_r=32,
    lora_alpha=16,
    lora_dropout=0.01,
)
# The axolotl config is built from the extractor's output dataset before the
# extraction itself runs, and is written to `output_dir` only afterwards.
distilled_dataset = extractor.output_dataset()
axolotl_config = lora_settings.llama_70b_axolotl(distilled_dataset)

extractor.run()
axolotl_config.save(output_dir)