| | """ |
| | Token position definitions for MCQA task submission. |
| | This file provides token position functions that identify key tokens in MCQA prompts. |
| | """ |
| |
|
| | import re |
| | from CausalAbstraction.model_units.LM_units import TokenPosition |
| |
|
| |
|
def get_last_token_index(prompt, pipeline):
    """
    Locate the final token of a tokenized prompt.

    Args:
        prompt (str): Text to tokenize.
        pipeline: Object whose ``load(prompt)`` result maps ``"input_ids"``
            to a batch of token-id sequences; only the first sequence is used.

    Returns:
        list[int]: A single-element list holding ``number_of_tokens - 1``
        (``[-1]`` if the first sequence is empty).
    """
    encoded = pipeline.load(prompt)
    first_sequence = encoded["input_ids"][0]
    # Materialize the sequence so any iterable of ids (list, tuple, 1-D tensor)
    # is counted the same way.
    token_count = len(list(first_sequence))
    return [token_count - 1]
| |
|
| |
|
def get_correct_symbol_index(prompt, pipeline, task):
    """
    Find the index of the correct answer symbol in the prompt.

    The task's causal model is run on the prompt to obtain the index of the
    correct option (``answer_pointer``) and the symbol that labels it
    (``symbol{answer_pointer}``). The first standalone occurrence of that
    symbol in the prompt is located. Standalone means it is neither preceded
    nor followed by a word character: the ``B`` in ``"B. blue"`` qualifies,
    the ``B`` in ``"Blue"`` does not. The prompt prefix ending at that
    occurrence is then tokenized, and the index of its last token is
    returned.

    Args:
        prompt (str): The prompt text
        pipeline: The tokenizer pipeline; ``pipeline.load(text)["input_ids"][0]``
            is used as the token ids of ``text``
        task: The task object containing causal model

    Returns:
        list[int]: List containing the index of the correct answer symbol token

    Raises:
        ValueError: If the correct symbol is not a non-empty string that
            occurs as a standalone token in the prompt.
    """
    output = task.causal_model.run_forward(task.input_loader(prompt))
    pointer = output["answer_pointer"]
    correct_symbol = output[f"symbol{pointer}"]

    # Search for the symbol directly instead of scanning only single uppercase
    # letters. For a single letter X, "(?<!\w)X(?!\w)" is equivalent to the
    # previous r"\bX\b" filter. Unlike \b, the lookarounds also behave
    # correctly for digits, lowercase, multi-character or punctuated symbols.
    symbol_match = None
    # An empty pattern would match at every position, and re.escape rejects
    # non-str values; both cases fall through to the "not found" error.
    if isinstance(correct_symbol, str) and correct_symbol:
        symbol_match = re.search(
            rf"(?<!\w){re.escape(correct_symbol)}(?!\w)", prompt
        )

    if symbol_match is None:
        raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

    # The last token of the prefix ending at the symbol is taken as the
    # symbol's position in the full prompt.
    # NOTE(review): this assumes the tokenizer splits the prefix the same way
    # as the full prompt, i.e. no merge across the cut; confirm for the
    # target tokenizer.
    substring = prompt[:symbol_match.end()]
    tokenized_substring = list(pipeline.load(substring)["input_ids"][0])
    return [len(tokenized_substring) - 1]
| |
|
| |
|
def get_token_positions(pipeline, task):
    """
    Build the token positions used for MCQA intervention experiments.

    Two positions are produced, in this order:
    - ``"correct_symbol"``: the token of the answer symbol that the task's
      causal model marks as correct (via ``get_correct_symbol_index``).
    - ``"last_token"``: the final token of the prompt
      (via ``get_last_token_index``).

    Args:
        pipeline: The language model pipeline with tokenizer
        task: The MCQA task object

    Returns:
        list[TokenPosition]: ``[correct_symbol_position, last_token_position]``
    """
    def correct_symbol_indexer(x):
        # Bind pipeline and task so TokenPosition only has to supply the prompt.
        return get_correct_symbol_index(x, pipeline, task)

    def last_token_indexer(x):
        return get_last_token_index(x, pipeline)

    correct_symbol_position = TokenPosition(
        correct_symbol_indexer, pipeline, id="correct_symbol"
    )
    last_token_position = TokenPosition(
        last_token_indexer, pipeline, id="last_token"
    )
    return [correct_symbol_position, last_token_position]