import logging

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "openai/clip-vit-large-patch14",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="CLIP ViT-L/14 model trained on COCO Captions",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("jxie/coco_captions")
train_dataset = dataset["train"].select(range(10_000))
eval_dataset = dataset["validation"].select(range(1_000))
test_dataset = dataset["test"].select(range(1_000))

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
run_name = "clip-vit-L14-coco"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to True if your GPU supports FP16 but not BF16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.01,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
eval_queries = {qid: sample["caption"] for qid, sample in enumerate(eval_dataset)}
eval_corpus = {sample["cocoid"]: sample["image"] for sample in eval_dataset}
eval_relevant_docs = {qid: [sample["cocoid"]] for qid, sample in enumerate(eval_dataset)}
eval_evaluator = InformationRetrievalEvaluator(
    queries=eval_queries,
    corpus=eval_corpus,
    relevant_docs=eval_relevant_docs,
    name="coco-eval",
)
eval_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(["image", "caption"]),
    eval_dataset=eval_dataset.select_columns(["image", "caption"]),
    loss=loss,
    evaluator=eval_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the test set
test_queries = {qid: sample["caption"] for qid, sample in enumerate(test_dataset)}
test_corpus = {sample["cocoid"]: sample["image"] for sample in test_dataset}
test_relevant_docs = {qid: [sample["cocoid"]] for qid, sample in enumerate(test_dataset)}
test_evaluator = InformationRetrievalEvaluator(
    queries=test_queries,
    corpus=test_corpus,
    relevant_docs=test_relevant_docs,
    name="coco-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name, private=True)
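
# (Optional) Quick retrieval sanity check with the finetuned model. This is a
# minimal sketch that is not part of the original script: it assumes the model
# was saved to models/{run_name}/final as above, and that this CLIP-based
# SentenceTransformer can encode both PIL images and caption strings via `encode`.
finetuned_model = SentenceTransformer(f"models/{run_name}/final")
sample = test_dataset[0]
image_embedding = finetuned_model.encode(sample["image"])
caption_embedding = finetuned_model.encode(sample["caption"])
# Higher similarity means the image and its caption are mapped closer together
print(finetuned_model.similarity(image_embedding, caption_embedding))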