In [None]:
from generators import *
from utils import *

In [16]:
import json



# System prompt template for the rule-regeneration request.
# It is rendered with str.format(num_steps=...) when the chat messages are
# built, so {num_steps} is the only placeholder and the literal JSON braces
# in the answer schema are doubled ({{ and }}).
system_message = """
You are provided with three components:
1. A verifier
2. A generator
3. Global variables and auxiliary functions that both the verifier and generator depend on

The generator is used to create two matrices, while the verifier is a transformation rule that checks if the transformation from input to output is valid.

Note that the two input parameters of the generator, diff_lb and diff_ub, are used to control the generation difficulty, especially the size of the generated grid. These parameters are important, and special attention should be paid to how they are used to adjust the difficulty when designing new generators.

Your tasks are:
1. Explain the effects of the original verifier and generator, as well as their relationship.
2. Referring to the provided verifier and generator, and using only the provided auxiliary functions, create a new pair of verifier and generator that are matched. 
The new transformation rule from input to output should be different from the original and contain exactly {num_steps} steps.
3. Ensure that your new transformation rule is both simple and interesting.

(Note: Only use the auxiliary functions provided here, and do not use any other functions or variables. Carefully review all available auxiliary functions to make full use of them.)

Your response should be in JSON format. Please ensure it can be parsed by json.loads(). Provide your answer using the following structure:
{{
  "original_reasoning": "Explanation of the original transformation reasoning",
  "new_verifier_reasoning": "[Your step-by-step reasoning about the new verifier, explaining how each step contributes to the overall transformation]",
  "new_verifier_code": "```python\\ndef verify_new(I: Grid) -> Grid:\\n    # Your verifier code here\\n```",
  "new_generator_reasoning": "[Your step-by-step reasoning about the new generator, including considerations for different difficulty levels]",
  "new_generator_code": "```python\\ndef generate_new(diff_lb: float, diff_ub: float) -> dict:\\n    # Your generator code here\\n```",
  "transformation_description": "A concise description of the new transformation rule",
  "num_of_rules": "{num_steps}"
}}

Break down complex problems into smaller parts and reason through them step by step, arriving at sub-conclusions before stating an overall conclusion. This reduces the extent to which you need to do large leaps of reasoning.

Reason in substantial detail as necessary to determine the transformation rule. Consider potential errors or edge cases and how to handle them.

Be creative and accomplished at solving puzzles. Here are some prompts to inspire creativity:
- How could you incorporate multiple colors or shapes?
- Can you create a rule that involves rotation or reflection?
- Is there a way to make the transformation dependent on the object's position in the grid?

Remember to thoroughly test your verifier and generator to ensure they work correctly together for various inputs.
"""


# Load every seed task record (one JSON object per line) into memory.
# UTF-8 is requested explicitly so reading uses the same encoding as the
# prompt file written further down, and blank lines are skipped instead of
# crashing json.loads.
with open('/mnt/data/zifeng.cao/reasoning/arc-agi/re-arc/origin_code.jsonl', 'r', encoding='utf-8') as f:
    data_list = [json.loads(line) for line in f if line.strip()]

name_freq = {'a61ba2ce': 0.002962726983115642, 'd406998b': 0.002962726983115642, 'f8ff0b80': 0.002954762663268557, 'feca6190': 0.002954762663268557, '681b3aeb': 0.002946798343421472, '1f642eb9': 0.002946798343421472, 'a699fb00': 0.0029388340235743868, '3befdf3e': 0.0029388340235743868, 'd5d6de2d': 0.0029388340235743868, '1e0a9b12': 0.0029388340235743868, '0962bcdd': 0.0029308697037273017, '7ddcd7ec': 0.0029308697037273017, '239be575': 0.0029308697037273017, '97999447': 0.0029308697037273017, 'aedd82e4': 0.0029229053838802166, '6d0160f0': 0.0029229053838802166, '810b9b61': 0.0029229053838802166, '137eaa0f': 0.0029149410640331315, '1caeab9d': 0.0029149410640331315, '54d9e175': 0.0029149410640331315, 'a1570a43': 0.0029149410640331315, '543a7ed5': 0.0029149410640331315, 'd89b689b': 0.0029149410640331315, '95990924': 0.0029069767441860465, 'e76a88a6': 0.0029069767441860465, '1f876c06': 0.0029069767441860465, '57aa92db': 0.0028990124243389614, '56ff96f3': 0.0028990124243389614, 'd23f8c26': 0.0028990124243389614, 'd43fd935': 0.0028990124243389614, '8eb1be9a': 0.0028990124243389614, '868de0fa': 0.0028990124243389614, 'b6afb2da': 0.0028990124243389614, 'd4469b4b': 0.0028910481044918763, '776ffc46': 0.0028910481044918763, 'c0f76784': 0.0028910481044918763, '6ecd11f4': 0.0028910481044918763, '746b3537': 0.0028910481044918763, '444801d8': 0.0028910481044918763, '53b68214': 0.0028910481044918763, '39a8645d': 0.0028910481044918763, '88a10436': 0.0028910481044918763, '025d127b': 0.0028830837846447912, '00d62c1b': 0.0028830837846447912, 'fcc82909': 0.0028830837846447912, '7e0986d6': 0.0028830837846447912, 'bb43febb': 0.0028830837846447912, '7f4411dc': 0.0028830837846447912, 'f1cefba8': 0.0028830837846447912, 'a85d4709': 0.0028830837846447912, '60b61512': 0.0028830837846447912, '1cf80156': 0.0028830837846447912, '25ff71a9': 0.002875119464797706, 'f35d900a': 0.002875119464797706, '5c0a986e': 0.002875119464797706, '1f0c79e5': 0.002875119464797706, '4347f46a': 0.002875119464797706, 
'be94b721': 0.002867155144950621, '41e4d17e': 0.002867155144950621, '5614dbcf': 0.002867155144950621, '99fa7670': 0.002867155144950621, 'b775ac94': 0.002867155144950621, '6cdd2623': 0.002867155144950621, 'ef135b50': 0.002867155144950621, 'd2abd087': 0.002867155144950621, '1c786137': 0.002859190825103536, '44d8ac46': 0.002859190825103536, 'bbc9ae5d': 0.002859190825103536, '27a28665': 0.002859190825103536, '1f85a75f': 0.002851226505256451, '7468f01a': 0.002851226505256451, '8f2ea7aa': 0.002851226505256451, 'c909285e': 0.002851226505256451, '97a05b5b': 0.002851226505256451, '913fb3ed': 0.002851226505256451, '22233c11': 0.002851226505256451, '995c5fa3': 0.002851226505256451, 'fcb5c309': 0.002843262185409366, '23b5c85d': 0.002843262185409366, 'ba26e723': 0.002843262185409366, 'd037b0a7': 0.002843262185409366, 'a740d043': 0.002843262185409366, 'cbded52d': 0.002843262185409366, '890034e9': 0.0028352978655622808, 'f8c80d96': 0.0028352978655622808, '045e512c': 0.0028352978655622808, '662c240a': 0.0028352978655622808, '11852cab': 0.0028352978655622808, 'b94a9452': 0.0028352978655622808, '46f33fce': 0.002827333545715196, '6e19193c': 0.002827333545715196, '40853293': 0.002827333545715196, '5168d44c': 0.002827333545715196, 'f8b3ba0a': 0.002827333545715196, 'b548a754': 0.002827333545715196, 'e21d9049': 0.002827333545715196, 'bda2d7a6': 0.002819369225868111, 'f9012d9b': 0.002819369225868111, 'ec883f72': 0.002819369225868111, '3bdb4ada': 0.002819369225868111, '6e82a1ae': 0.002819369225868111, '80af3007': 0.002819369225868111, '9f236235': 0.002811404906021026, 'd13f3404': 0.002811404906021026, '3de23699': 0.002803440586173941, 'bdad9b1f': 0.002803440586173941, '9565186b': 0.002803440586173941, '72ca375d': 0.002803440586173941, 'c1d99e64': 0.002803440586173941, '264363fd': 0.002795476266326856, 'b230c067': 0.002795476266326856, '91413438': 0.002795476266326856, 'f25ffba3': 0.002795476266326856, 'cf98881b': 0.002795476266326856, 'bc1d5164': 0.002795476266326856, '447fd412': 
0.002795476266326856, 'c444b776': 0.0027875119464797708, 'ce22a75a': 0.0027875119464797708, '09629e4f': 0.0027875119464797708, '3e980e27': 0.0027875119464797708, 'f5b8619d': 0.0027875119464797708, '36fdfd69': 0.0027875119464797708, '85c4e7cd': 0.0027795476266326857, 'b9b7f026': 0.0027795476266326857, 'd06dbe63': 0.0027795476266326857, '91714a58': 0.0027795476266326857, 'a3325580': 0.0027795476266326857, '2dc579da': 0.0027795476266326857, 'b27ca6d3': 0.0027715833067856006, '496994bd': 0.0027715833067856006, '6773b310': 0.0027715833067856006, 'ac0a08a4': 0.0027636189869385155, '228f6490': 0.0027636189869385155, 'd687bc17': 0.0027636189869385155, 'ff28f65a': 0.0027636189869385155, 'de1cd16c': 0.0027556546670914304, '93b581b8': 0.0027556546670914304, 'db93a21d': 0.0027556546670914304, '8e1813be': 0.0027556546670914304, '2204b7a8': 0.0027556546670914304, '6b9890af': 0.0027556546670914304, '321b1fc6': 0.0027556546670914304, 'e509e548': 0.0027556546670914304, '5117e062': 0.0027476903472443454, '32597951': 0.0027476903472443454, '6c434453': 0.0027476903472443454, '25d487eb': 0.0027476903472443454, '7b6016b9': 0.0027397260273972603, '2bee17df': 0.0027397260273972603, '6a1e5592': 0.0027317617075501752, '63613498': 0.00272379738770309, 'ecdecbb3': 0.00272379738770309, '7df24a62': 0.00272379738770309, '48d8fb45': 0.00272379738770309, '9aec4887': 0.00272379738770309, 'c8cbb738': 0.00272379738770309, '67385a82': 0.00272379738770309, '8efcae92': 0.00272379738770309, '484b58aa': 0.002715833067856005, '28bf18c6': 0.002715833067856005, '928ad970': 0.002715833067856005, '150deff5': 0.002715833067856005, 'e6721834': 0.00270786874800892, 'f8a8fe49': 0.00270786874800892, 'a78176bb': 0.00270786874800892, 'd511f180': 0.00270786874800892, '56dc2b01': 0.00270786874800892, '855e0971': 0.002699904428161835, '3f7978a0': 0.002699904428161835, 'e50d258f': 0.002699904428161835, 'a79310a0': 0.002699904428161835, '3aa6fb7a': 0.002699904428161835, '72322fa7': 0.002699904428161835, 'b0c4d837': 
0.002699904428161835, '445eab21': 0.00269194010831475, '0b148d64': 0.00269194010831475, '08ed6ac7': 0.00269194010831475, 'cce03e0d': 0.0026839757884676648, '29ec7d0e': 0.0026839757884676648, 'c3e719e8': 0.0026760114686205797, 'a5f85a15': 0.0026760114686205797, 'ae3edfdc': 0.0026760114686205797, 'a8d7556c': 0.0026760114686205797, '846bdb03': 0.0026680471487734946, 'e8593010': 0.0026680471487734946, 'd07ae81c': 0.0026680471487734946, '1e32b0e9': 0.0026680471487734946, '952a094c': 0.0026680471487734946, 'b527c5c6': 0.0026680471487734946, '50846271': 0.0026680471487734946, '5c2c9af4': 0.0026680471487734946, '29623171': 0.0026600828289264095, '9edfc990': 0.0026600828289264095, '4522001f': 0.0026600828289264095, '8a004b2b': 0.0026600828289264095, 'b190f7f5': 0.0026600828289264095, '4c4377d9': 0.0026521185090793245, 'a48eeaf7': 0.0026521185090793245, '4258a5f9': 0.0026521185090793245, '834ec97d': 0.0026521185090793245, '05269061': 0.0026521185090793245, 'caa06a1f': 0.0026521185090793245, '1b60fb0c': 0.0026441541892322394, '363442ee': 0.0026441541892322394, '06df4c85': 0.0026441541892322394, 'aabf363d': 0.0026441541892322394, '29c11459': 0.0026441541892322394, 'd9f24cd1': 0.0026441541892322394, '39e1d7f9': 0.0026361898693851543, 'a68b268e': 0.0026361898693851543, '90f3ed37': 0.0026282255495380697, '8403a5d5': 0.0026282255495380697, 'e73095fd': 0.0026282255495380697, '50cb2852': 0.0026282255495380697, '1190e5a7': 0.0026282255495380697, '5ad4f10b': 0.0026282255495380697, 'f2829549': 0.0026202612296909846, 'ce602527': 0.0026202612296909846, '36d67576': 0.0026202612296909846, 'f76d97a5': 0.0026202612296909846, '83302e8f': 0.0026122969098438995, 'f25fbde4': 0.0026122969098438995, 'e26a3af2': 0.0026122969098438995, '67e8384a': 0.0026043325899968144, 'e9614598': 0.0026043325899968144, '90c28cc7': 0.0026043325899968144, 'ce4f8723': 0.0026043325899968144, 'e9afcf9a': 0.0026043325899968144, '22eb0ac0': 0.0026043325899968144, '1fad071e': 0.0025963682701497293, '6455b5f5': 
0.0025884039503026443, '9af7a82c': 0.0025884039503026443, 'e98196ab': 0.0025884039503026443, '694f12f3': 0.0025884039503026443, 'e48d4e1a': 0.002580439630455559, '4c5c2cf0': 0.002580439630455559, '234bbc79': 0.002580439630455559, '8d510a79': 0.002580439630455559, '0dfd9992': 0.002572475310608474, 'ce9e57f2': 0.002572475310608474, '22168020': 0.002572475310608474, 'd9fac9be': 0.002564510990761389, 'a5313dff': 0.002564510990761389, 'b1948b0a': 0.002564510990761389, '6aa20dc0': 0.002564510990761389, '5521c0d9': 0.002556546670914304, 'd0f5fe59': 0.002556546670914304, 'ae4f1146': 0.002556546670914304, '6855a6e4': 0.002556546670914304, '017c7c7b': 0.002556546670914304, '4290ef0e': 0.002548582351067219, '82819916': 0.002540618031220134, '673ef223': 0.002540618031220134, '5582e5ca': 0.002540618031220134, '760b3cac': 0.002540618031220134, 'ea32f347': 0.002540618031220134, '794b24be': 0.0025326537113730487, '780d0b14': 0.0025326537113730487, 'c3f564a4': 0.0025246893915259637, '6e02f1e3': 0.0025246893915259637, '6430c8c4': 0.0025246893915259637, '1b2d62fb': 0.0025246893915259637, 'd4f3cd78': 0.0025246893915259637, '8e5a5113': 0.0025087607518317935, '74dd1130': 0.0025007964319847084, '44f52bb0': 0.0025007964319847084, 'b91ae062': 0.0024928321121376234, '68b16354': 0.0024928321121376234, '4093f84a': 0.0024848677922905383, 'a87f7484': 0.0024848677922905383, 'b2862040': 0.0024848677922905383, 'e3497940': 0.002476903472443453, '6d58a25d': 0.002476903472443453, '3428a4f5': 0.002476903472443453, 'af902bf9': 0.002476903472443453, '941d9a10': 0.002468939152596368, '1a07d186': 0.002468939152596368, 'fafffa47': 0.002468939152596368, '67a3c6ac': 0.002460974832749283, '0d3d703e': 0.002460974832749283, '963e52fc': 0.002453010512902198, '8be77c9e': 0.002453010512902198, 'dc1df850': 0.002429117553360943, 'd8c310e9': 0.002429117553360943, '4be741c5': 0.002421153233513858, '7837ac64': 0.002421153233513858, '7447852a': 0.002413188913666773, 'dc433765': 0.002405224593819688, 'ddf7fa4f': 
0.002397260273972603, '54d82841': 0.002397260273972603, '0a938d79': 0.002397260273972603, '9172f3a0': 0.002397260273972603, 'd6ad076f': 0.002389295954125518, 'd22278a0': 0.0023733673144313476, 'd631b094': 0.0023654029945842626, 'e40b9e2f': 0.0023574386747371775, 'cdecee7f': 0.0023574386747371775, '9dfd6313': 0.0023494743548900924, '4612dd53': 0.0023415100350430073, '2c608aff': 0.0023415100350430073, 'a61f2674': 0.0023415100350430073, 'dc0a314f': 0.0023415100350430073, '3345333e': 0.0023335457151959223, '3bd67248': 0.0023335457151959223, '5daaa586': 0.0023335457151959223, '6cf79266': 0.0023335457151959223, 'c59eb873': 0.0023335457151959223, '2bcee788': 0.0023335457151959223, '2dee498d': 0.002325581395348837, 'ea786f4a': 0.002317617075501752, '23581191': 0.002317617075501752, '6d0aefbc': 0.002317617075501752, 'e8dc4411': 0.002317617075501752, 'a3df8b1e': 0.002309652755654667, 'a65b410d': 0.002285759796113412, '3631a71a': 0.002285759796113412, '8731374e': 0.002285759796113412, '0520fde7': 0.002285759796113412, 'e5062a87': 0.0022777954762663267, 'a8c38be5': 0.0022777954762663267, '3618c87e': 0.0022698311564192416, '75b8110e': 0.0022539025167250715, 'c8f0f002': 0.0022539025167250715, 'db3e9e38': 0.0022459381968779864, 'dae9d2b5': 0.0022300095571838167, '67a423a3': 0.0022300095571838167, '007bbfb7': 0.0022220452373367316, 'ded97339': 0.0022220452373367316, '0e206a2e': 0.0022140809174896465, 'ba97ae07': 0.0022140809174896465, '42a50994': 0.0022061165976425615, '1bfc4729': 0.0022061165976425615, '3eda0437': 0.0021981522777954764, 'beb8660c': 0.0021981522777954764, '623ea044': 0.0021981522777954764, 'eb281b96': 0.0021822236381013062, '2281f1f4': 0.002174259318254221, '3ac3eb23': 0.002166294998407136, '62c24649': 0.002150366358712966, 'a64e4611': 0.0021264733991717107, '469497ad': 0.0020786874800892002, '178fcbfb': 0.0020786874800892002, '99b1bc43': 0.0020786874800892002, '10fcaaa3': 0.0020786874800892002, '6d75e8bb': 0.002054794520547945, '77fdfe62': 0.0020468302007008604, 
'eb5a1d5d': 0.0020388658808537753, '3906de3d': 0.00201497292131252, '25d8a9c8': 0.001991079961771265, '4938f0c2': 0.0019751513220770947, '28e73c20': 0.0019671870022300096, 'aba27056': 0.0019592226823829245, '253bf280': 0.0019432940426887544, '05f2a901': 0.0019353297228416693, 'b782dc8a': 0.0019273654029945842, '7b7f7511': 0.0019273654029945842, '8d5021e8': 0.0019273654029945842, 'e179c5f4': 0.0019194010831474991, '98cf29f8': 0.0018955081236062441, '2dd70a9a': 0.001887543803759159, '7c008303': 0.001879579483912074, 'a416b8f3': 0.0018477222045237337, '272f95fa': 0.0018477222045237337, 'b60334d2': 0.0017999362854412234, 'c9e6f938': 0.0017840076457470533, '49d1d64f': 0.001752150366358713, 'b7249182': 0.001752150366358713, '9d9215db': 0.0017441860465116279, 'bd4472b8': 0.0017362217266645428, '94f9d214': 0.0017202930869703727, '2013d3e2': 0.0017202930869703727, '0ca9ddb6': 0.0016964001274291176, 'b8cdaf2b': 0.0016964001274291176, '6fa7a44f': 0.0016964001274291176, '508bd3b6': 0.0016645428480407773, 'b8825c91': 0.0016565785281936923, 'd10ecb37': 0.001632685568652437, '6f8cd79b': 0.001624721248805352, 'dbc1a6ce': 0.0015928639694170119, '3af2c5a8': 0.0015610066900286716, '73251a56': 0.0014574705320165658, 'f15e1fac': 0.0014495062121694807, 'ed36ccf7': 0.0014017202930869705, 'a9f96cdd': 0.0013698630136986301, '539a4f51': 0.0013698630136986301, '47c1f68c': 0.00135393437400446, 'a2fd1cf0': 0.0012424338961452691, 'ff805c23': 0.001226505256451099, '88a62173': 0.0011707550175215037, 'd4a91cb9': 0.0011150047785919083, '3c9b0459': 0.0011070404587448233, '5bd6f4ac': 0.0010911118190506531, '31aa019c': 0.0010433258999681427, 'd364b489': 0.000955718381650207, '9ecd008a': 0.0009158967824147818, 'd90796e8': 0.0009079324625676967, 'c9f8e694': 0.0006530742274609749, '7fe24cdd': 0.0006212169480726346, '6150a2bd': 0.0003743230328129978, '46442a0e': 0.0}
# Create the chat prompts for every seed task.
# Each task is repeated according to its weight in name_freq (tasks missing
# from the table fall back to 30 repetitions); every repetition emits one
# prompt per (num_steps, repeat_times) slot of num_steps_config.
num_steps_config = {2: 1, 3: 2, 4: 2, 5: 2, 6: 1, 7: 1}

# The system prompt only depends on num_steps, so render each variant once
# instead of once per generated entry.
system_prompts = {
    num_steps: system_message.format(num_steps=num_steps)
    for num_steps in num_steps_config
}

entries = []
for data in data_list:
    if data["name"] not in name_freq:
        name_repeat = 30
    else:
        name_repeat = int(3000 * name_freq[data["name"]] + 1)

    # The user prompt is identical for all repetitions of the same task.
    user_message = f"Here are the components:\n\nVerifier:\n{data['verifier']}\n\nGenerator:\n{data['generator']}\n\nGlobal Variables and Auxiliary Functions:\n{data['global_variable']}\n{data['additional_functions']}"

    for _task_repeat in range(name_repeat):
        for num_steps, repeat_times in num_steps_config.items():
            for _slot_repeat in range(repeat_times):
                messages = [
                    {
                        "role": "system",
                        "content": system_prompts[num_steps]
                    },
                    {
                        "role": "user",
                        "content": user_message
                    }
                ]

                # Each entry becomes one JSON line in the output file.
                entry = {
                    "name": f"{data['name']}_{num_steps}steps",
                    "messages": messages
                }
                entries.append(entry)

# import numpy as np
# # Optionally shuffle the entries before writing
# random.shuffle(entries)

# Write all entries at once, one JSON object per line.
output_path = 'generator_verifier_regenerator_prompt_simplified_fixed_freq.jsonl'
with open(output_path, 'w', encoding='utf-8') as file:
    for entry in entries:
        json.dump(entry, file, ensure_ascii=False)
        file.write('\n')

# Report the path that was actually written.
print(f"All dialogue entries have been successfully archived in {output_path}")


All dialogue entries have been successfully archived in generator_verifier_regenerator_prompt_simplified.jsonl


In [None]:
import numpy as np
import importlib
import random

def import_generator_and_verifier(id):
    """Look up the generator/verifier pair for one task id.

    ``generate_<id>`` is read from the ``generators`` module and
    ``verify_<id>`` from the ``verifiers`` module.

    Returns:
        A ``(generator_function, verifier_function)`` tuple.
    """
    lookups = (
        ('generators', f'generate_{id}'),
        ('verifiers', f'verify_{id}'),
    )
    # The generator expression resolves the pairs lazily and in order, so the
    # generators module is imported and queried before verifiers is touched.
    generator_function, verifier_function = (
        getattr(importlib.import_module(module_name), attribute)
        for module_name, attribute in lookups
    )
    return generator_function, verifier_function

# Task id whose generator/verifier pair is previewed in this cell.
t_id = '3631a71a'
generate_func, verify_func = import_generator_and_verifier(t_id)


# Generate five samples at difficulty bounds (0.5, 0.5); the "output" of each
# sample is replaced by what the verifier computes from the generated input.
verifier_output = []
result_dict_list = []
for i in range(5):
    result_dict = generate_func(0.5, 0.5)
    result_dict["output"] = verify_func(result_dict['input'])
    verifier_output.append(result_dict["output"])
    # Convert every value to a numpy array before handing the sample to plot_task.
    result_dict = {k: np.array(v) for k, v in result_dict.items()}
    result_dict_list.append(result_dict)

plot_task(result_dict_list)

# for i in range(5):
#     assert np.array_equal(result_dict_list[i]['input'], verifier_output[i])

# Fetch the source code of verify_func (reused by the function-splitting cell).
import inspect
verify_func_source = inspect.getsource(verify_func)
print("verify_func的具体实现:")
print(verify_func_source)




In [None]:
import ast
import astor

class FunctionSplitter(ast.NodeTransformer):
    """Split Grid-returning functions into cumulative ``<name>_stepN`` functions.

    For every matching function, one step function is generated per statement
    except the last one. Step ``i`` holds the first ``i + 1`` statements of the
    original body, so the last step reproduces the complete body (including
    its final statement). The original function is then rewritten to simply
    return the result of calling the last step with its own positional
    parameters.

    The generated step functions are collected in ``new_functions``; the
    caller is responsible for appending them to the module body.

    NOTE(review): intermediate step functions contain no ``return`` of their
    own, and the slice ``body[:i+1]`` includes one statement more than the
    step index suggests -- confirm this is the intended step granularity.
    """

    def __init__(self):
        # Step functions produced by all visited functions so far.
        self.new_functions = []

    @staticmethod
    def _returns_grid(node):
        """Return True if *node* is annotated ``-> Grid`` or contains ``return Grid``."""
        annotation = node.returns
        if isinstance(annotation, ast.Name) and annotation.id == 'Grid':
            return True
        # Original detection, kept for backward compatibility: a return
        # statement whose value is the bare name ``Grid``.
        return any(isinstance(child, ast.Return) and
                   isinstance(child.value, ast.Name) and
                   child.value.id == 'Grid'
                   for child in ast.walk(node))

    def visit_FunctionDef(self, node):
        """Generate step functions for *node* and turn it into a thin wrapper."""
        if not self._returns_grid(node):
            return node

        # Build the cumulative step functions.
        step_functions = []
        for i, _stmt in enumerate(node.body[:-1], start=1):
            step_func = ast.FunctionDef(
                name=f"{node.name}_step{i}",
                args=node.args,
                body=node.body[:i+1],  # statements 0..i of the original body
                decorator_list=[],
                returns=node.returns
            )
            step_functions.append(step_func)

        if not step_functions:
            # A single-statement body has nothing to split; rewriting it would
            # produce a call to a non-existent ``<name>_step0``.
            return node

        # Forward the function's own positional parameters to the last step
        # instead of assuming the parameter is always called ``I``.
        forwarded = [
            ast.Name(id=param.arg, ctx=ast.Load())
            for param in node.args.posonlyargs + node.args.args
        ]

        # Replace the original body with a call to the last step function.
        node.body = [
            ast.Return(
                value=ast.Call(
                    func=ast.Name(id=f"{node.name}_step{len(step_functions)}", ctx=ast.Load()),
                    args=forwarded,
                    keywords=[]
                )
            )
        ]

        self.new_functions.extend(step_functions)
        return node

def split_functions(source_code):
    """Run FunctionSplitter over *source_code* and return the regenerated source.

    The step functions collected by the transformer are appended to the end
    of the module before it is rendered back to text with astor.
    """
    splitter = FunctionSplitter()
    module = splitter.visit(ast.parse(source_code))

    # Append the generated step functions after the existing definitions.
    module.body += splitter.new_functions

    return astor.to_source(module)

# Use verify_func_source (the verifier source captured earlier) as the input.
transformed_code = split_functions(verify_func_source)

# Show the transformed code.
print("转换后的代码:")
print(transformed_code)

In [None]:
import os
import re
import json

from generators import *
from utils import *

def read_json_files(folder_path):
    """Load every ``.json`` file found directly inside *folder_path*.

    Returns:
        A dict mapping each file name (with the ``.json`` suffix removed) to
        its parsed content. Files that cannot be parsed or read are reported
        on stdout and left out of the result.
    """
    loaded = {}
    for filename in os.listdir(folder_path):
        if not filename.endswith('.json'):
            continue
        full_path = os.path.join(folder_path, filename)
        try:
            with open(full_path, 'r', encoding='utf-8') as handle:
                parsed = json.load(handle)
            key = re.sub(r'\.json$', '', filename)
            loaded[key] = parsed
        except json.JSONDecodeError:
            print(f"警告: 文件 '{filename}' 不是有效的JSON格式。已跳过。")
        except Exception as e:
            print(f"读取文件 '{filename}' 时发生错误: {str(e)}")
    return loaded


# Original ARC training tasks: keep only the "train" examples of every task,
# each passed through format_example.
raw_training_data = read_json_files('/mnt/data/zifeng.cao/reasoning/arc-agi/re-arc/arc_original/training')
original_training_io_data = {k: [format_example(example) for example in v["train"]] for k, v in raw_training_data.items()}

# Generator-produced examples: here every file is iterated directly as a list
# of examples.
raw_generator_data = read_json_files('/mnt/data/zifeng.cao/reasoning/arc-agi/re-arc/re-arc-4-example/tasks')
original_generator_io_data = {k: [format_example(example) for example in v] for k, v in raw_generator_data.items()}

import pickle

# Save both tables to a pickle file so later cells can reload them.
with open('io_data.pkl', 'wb') as f:
    pickle.dump({
        'original_training_io_data': original_training_io_data,
        'original_generator_io_data': original_generator_io_data
    }, f)



In [6]:
import json
import os
import re
import signal
import resource
import multiprocessing
from generators import *
from utils import *
import hashlib

def stable_hash(obj):
    """Return the MD5 hex digest of *obj* serialized as JSON with sorted keys."""
    canonical = json.dumps(obj, sort_keys=True)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()


def get_jsonl_files(directory):
    """
    Return the paths of all ``.jsonl`` files located directly in *directory*.

    The previous implementation stopped after the first match, so at most one
    file was ever returned. File names are sorted so the result order does not
    depend on the order in which the operating system lists the directory.
    """
    return [
        os.path.join(directory, filename)
        for filename in sorted(os.listdir(directory))
        if filename.endswith('.jsonl')
    ]

def read_jsonl_file(file_path):
    """
    Read one .jsonl file and return its lines (line endings are kept).
    """
    with open(file_path, 'r') as handle:
        lines = list(handle)
    return lines

def extract_json_from_line(line):
    """
    Extract the JSON object embedded in the 'response' field of one line.

    The line itself is parsed as JSON first. If it has a 'response' key, the
    text from the first '{' to the last '}' on a single line of that response
    is parsed as JSON and returned. None is returned when there is no
    'response' key, no brace-delimited span, or when either parse fails.
    """
    try:
        record = json.loads(line)
        if 'response' not in record:
            return None

        # No DOTALL flag: the braces must sit on the same line of the response.
        found = re.search(r'\{.*\}', record['response'])
        if not found:
            return None

        return json.loads(found.group())
    except json.JSONDecodeError:
        # Either the line or the embedded span is not valid JSON.
        return None

def extract_json_from_content(content_list):
    """
    Apply extract_json_from_line to every line and keep the truthy results.

    Lines for which nothing could be extracted (or whose extracted value is
    empty) are dropped; the order of the remaining items follows the input.
    """
    candidates = (extract_json_from_line(line) for line in content_list)
    return [candidate for candidate in candidates if candidate]


def timeout_handler(signum, frame):
    """Signal handler that turns the received signal into a TimeoutError."""
    message = "Execution timeout"
    raise TimeoutError(message)


def extract_code_from_line(line):
    """Pull the proposed verifier/generator code out of one rollout line.

    The line is parsed as JSON. Its 'response' text is searched for a
    brace-delimited span that contains both the "original_reasoning": and
    "num_of_rules": keys, and that span is parsed as JSON. The task name is the
    part of the 'name' field before the first underscore.

    Returns:
        None when no usable JSON answer or no task name is found, otherwise a
        dict with origin_name, num_of_rules, verifier_code, generator_code and
        transformation_description. Code strings have their surrounding
        Markdown fence removed and literal backslash-n sequences turned into
        real newlines.
    """
    record = json.loads(line)
    answer = None

    if 'response' in record:
        found = re.search(r'\{(?=.*"original_reasoning":)(?=.*"num_of_rules":).*\}', record['response'], re.DOTALL)
        if found:
            try:
                answer = json.loads(found.group())
            except json.JSONDecodeError:
                answer = None

    origin_name = record.get("name", "").split("_")[0]

    if answer is None or len(origin_name) == 0:
        return None

    def _clean(snippet):
        # Drop the leading/trailing code fence, then unescape "\\n" and "\n" text.
        unfenced = re.sub(r'^```(?:python)?\s*|\s*```$', '', snippet.strip())
        return unfenced.replace('\\\\n', '\n').replace('\\n', '\n')

    verifier_code = _clean(answer.get('new_verifier_code', ''))
    generator_code = _clean(answer.get('new_generator_code', ''))

    return {
        "origin_name": origin_name,
        "num_of_rules": answer.get("num_of_rules", 0),
        "verifier_code": verifier_code,
        "generator_code": generator_code,
        "transformation_description": answer.get("transformation_description", ""),
    }


def execute_code(code, timeout, memory_limit=10*1024*1024):  # default cap is 10 MiB (10*1024*1024 bytes)
    """Execute *code* in a child process and report whether it succeeded.

    The child installs a SIGALRM-based timeout, caps its address space with
    RLIMIT_AS, and runs the code in a namespace pre-populated with everything
    from the ``generators`` and ``utils`` modules.

    Args:
        code: Python source to execute.
        timeout: Seconds before the alarm fires; also the time the parent
            waits before terminating the child.
        memory_limit: Soft and hard RLIMIT_AS value in bytes.

    Returns:
        True only if the code ran to completion without raising; False on
        timeout, memory exhaustion, any other exception, or if the child had
        to be terminated.

    SECURITY: *code* is passed to exec(). It is only confined by the resource
    limits above, so do not feed it input you would not run yourself.
    NOTE(review): the nested function is used as the process target, which
    presumably relies on the fork start method -- confirm on other platforms.
    """
    def isolated_execution():
        def limit_memory():
            resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))

        try:
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(timeout)
            limit_memory()
            # Import the helper modules inside the child.
            import generators
            import utils
            # Global namespace for exec, seeded with both modules' namespaces.
            exec_globals = {}
            exec_globals.update(generators.__dict__)
            exec_globals.update(utils.__dict__)
            # Run the code.
            exec(code, exec_globals)
            signal.alarm(0)
            return True
        except TimeoutError:
            print("代码执行超时")
            return False
        except MemoryError:
            print("内存超出限制")
            return False
        except Exception as e:
            # print(f"代码执行错误: {e}")
            return False

    def child_main():
        # A process target's return value is discarded, so failure has to be
        # reported through the exit status; otherwise the exit code is 0 even
        # when the executed code raised.
        if not isolated_execution():
            raise SystemExit(1)

    process = multiprocessing.Process(target=child_main)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        print("代码执行超时")
        return False

    return process.exitcode == 0

def execute_function(func_name, args, code, timeout, memory_limit=10*1024*1024):  # default cap is 10 MiB (10*1024*1024 bytes)
    """Execute *code* in a child process, call one of its functions, return the result.

    The child installs a SIGALRM-based timeout, caps its address space with
    RLIMIT_AS, executes *code* in a namespace pre-populated with everything
    from the ``generators`` and ``utils`` modules, calls
    ``func_name(*args)`` and sends the result back through a queue.

    Args:
        func_name: Name of the function defined by *code* that should be called.
        args: Positional arguments for that function.
        code: Python source to execute.
        timeout: Seconds before the alarm fires; also how long the parent
            waits for a result.
        memory_limit: Soft and hard RLIMIT_AS value in bytes.

    Returns:
        The function's return value, or None on timeout, memory exhaustion,
        any other exception, or when the child produced no result.

    SECURITY: *code* is passed to exec(). It is only confined by the resource
    limits above, so do not feed it input you would not run yourself.
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    def isolated_execution(queue):
        def limit_memory():
            resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))

        try:
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(timeout)
            limit_memory()
            # Import the helper modules inside the child.
            import generators
            import utils
            # Global namespace for exec, seeded with both modules' namespaces.
            exec_globals = {}
            exec_globals.update(generators.__dict__)
            exec_globals.update(utils.__dict__)
            # Run the code.
            exec(code, exec_globals)
            # Call the requested function.
            result = exec_globals[func_name](*args)
            signal.alarm(0)
            queue.put(result)
        except TimeoutError:
            print(f"{func_name} 执行超时")
            queue.put(None)
        except MemoryError:
            print(f"{func_name} 内存超出限制")
            queue.put(None)
        except Exception as e:
            # print(f"代码执行错误: {e}")
            queue.put(None)

    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=isolated_execution, args=(result_queue,))
    process.start()

    # Read the result BEFORE joining: a child that has put data on a queue
    # does not exit until that data is flushed, so joining first can turn a
    # large but valid result into a spurious timeout. Waiting on get() also
    # avoids relying on Queue.empty(), which is documented as unreliable.
    try:
        result = result_queue.get(timeout=timeout)
        got_result = True
    except Empty:
        result = None
        got_result = False

    # With a result in hand the child only needs to shut down; without one,
    # the timeout budget has already been spent waiting on the queue.
    process.join(timeout if got_result else 0)

    if process.is_alive():
        process.terminate()
        process.join()
        if not got_result:
            print(f"{func_name} 执行超时")
            return None

    return result

def check_generator_verifier_match(generator_code, verifier_code, timeout):
    """Check that a generated sample's output equals what the verifier computes.

    The first ``def`` found in each code string is taken as the entry point.
    The generator is called with difficulty bounds (0.5, 0.5) and must return
    a dict with "input" and "output"; the verifier is then run on that input.

    Returns:
        ``(matched, difficulty)``. ``matched`` is True when the verifier's
        result equals the generator's "output". ``difficulty`` is
        ``get_pso_difficulty`` of the generated sample, or None whenever the
        check could not be carried out (in which case ``matched`` is False).
    """
    generator_func_match = re.search(r'def\s+(\w+)', generator_code)
    verifier_func_match = re.search(r'def\s+(\w+)', verifier_code)
    if generator_func_match is None or verifier_func_match is None:
        return False, None

    generator_func_name = generator_func_match.group(1)
    verifier_func_name = verifier_func_match.group(1)

    # Run the generator and fetch one sample.
    generator_result = execute_function(generator_func_name, (0.5, 0.5), generator_code, timeout)
    # Generated code may return anything, so make sure it is a dict before
    # looking for the expected keys.
    if (not isinstance(generator_result, dict)
            or "input" not in generator_result
            or "output" not in generator_result):
        return False, None

    # Combine generator and verifier code so the verifier can reach any helper
    # it shares with the generator.
    combined_code = generator_code + '\n' + verifier_code

    # Actually run the verifier. Previously only the argument tuple was
    # assigned here (the execute_function call was missing), so the comparison
    # below could never succeed.
    verifier_output = execute_function(verifier_func_name, (generator_result['input'],), combined_code, timeout)
    if verifier_output is None:
        return False, None

    return generator_result['output'] == verifier_output, get_pso_difficulty(generator_result)

def check_verifier_on_data(verifier_code, io_data, timeout):
    """Run the verifier on every example's input and require a grid each time.

    The first ``def`` found in *verifier_code* is taken as the entry point.

    Returns:
        ``(success, output_hash, difficulty)``. On success, ``output_hash`` is
        ``stable_hash`` of the verifier's output for the LAST example and
        ``difficulty`` is ``get_pso_difficulty`` of that last example.
        ``(False, None, None)`` is returned when no function is found, when any
        run fails or does not yield a grid, or when *io_data* is empty (the
        empty case used to raise UnboundLocalError).
    """
    verifier_func_match = re.search(r'def\s+(\w+)', verifier_code)
    if verifier_func_match is None:
        return False, None, None

    verifier_func_name = verifier_func_match.group(1)

    last_output = None
    last_item = None
    for io_item in io_data:
        # Call the verifier on this example's input.
        verifier_output = execute_function(verifier_func_name, (io_item['input'],), verifier_code, timeout)
        if verifier_output is None or not is_grid(verifier_output):
            return False, None, None
        last_output, last_item = verifier_output, io_item

    if last_item is None:
        # Nothing was checked, so there is nothing to hash or rate.
        return False, None, None

    return True, stable_hash(last_output), get_pso_difficulty(last_item)

def execute_and_evaluate(extracted_data, original_training_io_data, original_generator_io_data, timeout):
    """Evaluate one extracted verifier/generator proposal.

    Three checks are run: the new generator against the new verifier, the new
    verifier on the original training examples of the source task, and the new
    verifier on the generator-produced examples of the source task.

    Returns:
        None when *extracted_data* is empty, otherwise a dict with the
        match/success flags, output hashes and difficulty values of the
        three checks.
    """
    if not extracted_data:
        return None

    task_name = extracted_data["origin_name"]
    new_verifier = extracted_data["verifier_code"]
    new_generator = extracted_data["generator_code"]

    # Both lookups happen before anything is executed.
    training_examples = original_training_io_data[task_name]
    generator_examples = original_generator_io_data[task_name]

    match_ok, match_difficulty = check_generator_verifier_match(new_generator, new_verifier, timeout)
    train_ok, train_hash, train_difficulty = check_verifier_on_data(new_verifier, training_examples, timeout)
    gen_ok, gen_hash, gen_difficulty = check_verifier_on_data(new_verifier, generator_examples, timeout)

    report = {
        "new_generator_verifier_match": match_ok,
        "new_generator_difficulty": match_difficulty,
        "origin_training_verifier_success": train_ok,
        "origin_generator_verifier_success": gen_ok,
        "origin_training_verifier_hash": train_hash,
        "origin_generator_verifier_hash": gen_hash,
        "origin_training_verifier_difficulty": train_difficulty,
        "origin_generator_verifier_difficulty": gen_difficulty,
    }
    return report


In [7]:
import pickle

# Read the cached I/O data back from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load io_data.pkl when it comes from a trusted source.
with open('io_data.pkl', 'rb') as f:
    loaded_data = pickle.load(f)

# Pull out the two tables that are later passed to execute_and_evaluate,
# which indexes each of them by origin name.
original_training_io_data = loaded_data['original_training_io_data']
original_generator_io_data = loaded_data['original_generator_io_data']

# Confirmation message ("data loaded successfully").
print("数据已成功加载")

数据已成功加载


In [8]:
# Usage example
# Imports and the worker function are set up once here instead of being
# re-executed on every iteration of the per-file loop.
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing

directory = '/mnt/data/zifeng.cao/reasoning/arc-agi/rollout/sampling_code/arc_new_rule_verifier/Qwen2.5-72B-Instruct_ARC_NEW_RULE_1024_SEQ-LEN_8192_temperature_0.7_world-size_8_n-worker-per-node_4'

# Collect every .jsonl file under the directory
jsonl_files = get_jsonl_files(directory)


def process_line(line, original_training_io_data, original_generator_io_data):
    """Extract, execute and score the code carried by one JSONL record.

    Args:
        line: One record as returned by ``read_jsonl_file``.
        original_training_io_data: Reference table forwarded to
            ``execute_and_evaluate``.
        original_generator_io_data: Second reference table forwarded to
            ``execute_and_evaluate``.

    Returns:
        A dict describing the candidate when at least one of the three
        checks succeeded; ``None`` when extraction or evaluation produced
        nothing, when every check failed, or when any exception was raised.
    """
    try:
        print(line)
        extracted_data = extract_code_from_line(line)
        print(extracted_data)
        if not extracted_data:
            return None

        result = execute_and_evaluate(extracted_data, original_training_io_data, original_generator_io_data, timeout=1)
        if not result:
            return None

        # Keep the candidate only if at least one check passed.
        if not (result["new_generator_verifier_match"] or result["origin_training_verifier_success"] or result["origin_generator_verifier_success"]):
            return None

        return {
            "name": extracted_data["origin_name"],
            "num_of_rules": extracted_data["num_of_rules"],
            "verifier_code": extracted_data["verifier_code"],
            "generator_code": extracted_data["generator_code"],
            "generator_difficulty": result["new_generator_difficulty"],
            "transformation_description": extracted_data["transformation_description"],
            "origin_training_verifier_output_hash": result["origin_training_verifier_hash"],
            "origin_training_verifier_difficulty": result["origin_training_verifier_difficulty"],
            "origin_generator_verifier_output_hash": result["origin_generator_verifier_hash"],
            "origin_generator_verifier_difficulty": result["origin_generator_verifier_difficulty"],
        }
    except Exception as e:
        # Best-effort: a failing record is reported and skipped so it does
        # not abort the remaining records.
        print(f"处理行时发生错误: {str(e)}")
        return None


max_workers = multiprocessing.cpu_count()

for file in jsonl_files:
    file_contents = read_jsonl_file(file)
    # NOTE(review): only the first 10 records are processed; together with
    # the trailing `break` this looks like a debugging limit - confirm
    # before running on the full data set.
    file_contents = file_contents[:10]

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_line = {
            executor.submit(process_line, line, original_training_io_data, original_generator_io_data): line
            for line in file_contents
        }
        # Gather results in completion order; falsy results are discarded.
        for future in as_completed(future_to_line):
            result = future.result()
            if result:
                results.append(result)

    # Writing the surviving results next to the input file is currently
    # disabled:
    # if len(results) > 0:
        # new_file_name = file.rsplit('.', 1)[0] + '_verify.jsonl'
        # with open(new_file_name, "w") as f:
        #     print(new_file_name)
        #     for result in results:
        #         f.write(json.dumps(result) + "\n")
    print(len(file_contents), len(results))
    break

{"name": "4258a5f9_4steps", "response": "```json\n{\n  \"original_reasoning\": \"The original verifier and generator work together to create a transformation where the least common color in the input grid is identified, and then the cells of that color and their neighbors are filled with a specific value (1). The generator creates a grid with a background color and places a certain number of dots of a different color. The verifier then checks if the transformation is applied correctly by filling the neighbors of the dots with the value 1.\",\n  \"new_verifier_reasoning\": \"The new verifier will perform a transformation that involves rotating the grid by 90 degrees clockwise, then reflecting it horizontally, and finally filling the cells that were originally on the boundary of the grid with a specific color (2). The steps are as follows:\\n1. Rotate the grid 90 degrees clockwise.\\n2. Reflect the grid horizontally.\\n3. Identify the cells that were originally on the boundary of the gri

In [6]:
# Debug inspection: print whatever `result` is currently bound to in the notebook namespace.
print(result)

{'new_generator_verifier_match': False, 'origin_training_verifier_success': False, 'origin_generator_verifier_success': False}
