"""
Unified Code-Specialized Model2Vec Distillation Script.
This script provides a unified approach for creating code-specialized embeddings
using Model2Vec distillation with optional code-specific training.
Features:
- Basic distillation (default): Simple Model2Vec distillation
- Advanced training (--train flag): Additional CodeSearchNet fine-tuning
- Checkpoint support with Beam sync utilities
- Multi-teacher model processing
- Smart resume capabilities
- Hierarchical storage: base -> final
Directory Structure:
- code_model2vec/base: Basic distilled models (first step)
- code_model2vec/final: Final models (copied from base or after training)
Usage:
distiller distill [--use-beam] [--train] # Basic distillation or with training
"""
import importlib.util
import json
import logging
import os
import time
from pathlib import Path
from typing import Annotated, Any
import torch
import typer
from beam import function
from sentence_transformers import SentenceTransformer
from distiller.model2vec.distill import distill
# Try to import flash_attn to check if it's available
from .beam_utils import (
BeamCheckpointManager,
create_beam_utilities,
download_model_from_beam,
sync_checkpoints_from_beam,
sync_checkpoints_to_beam,
upload_model_to_beam,
)
from .config import (
codesearchnet_config,
directories,
distillation_config,
get_distillation_function_kwargs,
get_training_function_kwargs,
get_volume_config,
languages_config,
)
# Check if flash_attn is available and compatible
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
# =============================================================================
# CONFIGURATION
# =============================================================================
VOLUME_CONFIG = get_volume_config()
LOCAL_BASE_DIR = directories.base
LOCAL_FINAL_DIR = directories.final
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Teacher models for distillation
DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
# =============================================================================
# FLASH ATTENTION UTILITIES
# =============================================================================
def configure_flash_attention() -> dict[str, Any]:
"""Configure flash attention settings and return model kwargs."""
model_kwargs: dict[str, Any] = {}
if not FLASH_ATTN_AVAILABLE:
logger.info("β οΈ Flash attention not available - using standard attention")
return model_kwargs
# Set environment variables for flash attention
os.environ["FLASH_ATTENTION_FORCE_USE"] = "1"
# Disable torch compile for flash attention compatibility
os.environ["TORCH_COMPILE_DISABLE"] = "1"
# Enable flash attention in transformers
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Check if we're on a compatible GPU
try:
if torch.cuda.is_available():
device_capability = torch.cuda.get_device_capability()
# Flash attention requires compute capability >= 7.5 (Turing, Ampere, Ada, Hopper)
if device_capability[0] >= 7 and (device_capability[0] > 7 or device_capability[1] >= 5):
logger.info("β
Flash attention enabled - compatible GPU detected")
model_kwargs.update(
{
"model_kwargs": {
"attn_implementation": "flash_attention_2",
"torch_dtype": torch.float16, # Flash attention works best with fp16
"use_flash_attention_2": True,
"_attn_implementation": "flash_attention_2", # Alternative key for some models
}
}
)
else:
logger.info(f"β οΈ GPU compute capability {device_capability} < 7.5 - flash attention disabled")
else:
logger.info("β οΈ No CUDA available - flash attention disabled")
except Exception as e:
logger.warning(f"β οΈ Failed to check GPU compatibility: {e} - flash attention disabled")
return model_kwargs
def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer:
"""Load a SentenceTransformer model with flash attention if available."""
flash_kwargs = configure_flash_attention()
try:
# Try loading with flash attention first
if flash_kwargs and "model_kwargs" in flash_kwargs:
logger.info(f"π Loading model with flash attention: {Path(model_path).name}")
model = SentenceTransformer(model_path, device=device, trust_remote_code=True, **flash_kwargs)
logger.info("β
Model loaded successfully with flash attention")
return model
except Exception as e:
logger.warning(f"β οΈ Failed to load with flash attention: {e}")
logger.info("π Falling back to standard attention")
# Fallback to standard loading
logger.info(f"π Loading model with standard attention: {Path(model_path).name}")
model = SentenceTransformer(model_path, device=device, trust_remote_code=True)
logger.info("β
Model loaded successfully with standard attention")
return model
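# Minimal usage sketch for the loader above (the model name is an example only; any
# SentenceTransformer-compatible teacher from the configured list would work the same way):
#
#   teacher = load_model_with_flash_attention("BAAI/bge-base-en-v1.5", device="cuda")
#   vecs = teacher.encode(["def add(a, b):\n    return a + b"], normalize_embeddings=True)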
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def get_current_config_hash(enable_training: bool) -> str:
"""Generate a hash of current configuration parameters for checkpoint validation."""
import hashlib
config_params = {
"pca_dims": distillation_config.optimal_pca_dims,
"sif_coefficient": distillation_config.sif_coefficient,
"apply_zipf": distillation_config.apply_zipf,
"enable_training": enable_training,
}
if enable_training:
# Add a simple hash of tokenlearn parameters for config validation
tokenlearn_hash = hash(
f"{distillation_config.tokenlearn_dataset}_{distillation_config.tokenlearn_dataset_name}_{distillation_config.tokenlearn_text_key}"
)
config_params["tokenlearn_hash"] = float(abs(tokenlearn_hash) % 1000000) # Convert to float for consistency
config_str = str(sorted(config_params.items()))
return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
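# Note: the hash above is only used to detect configuration drift between runs. For a fixed
# configuration it is deterministic, and toggling training changes it, e.g.:
#   get_current_config_hash(enable_training=False) != get_current_config_hash(enable_training=True)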
def check_existing_base_model(teacher_name: str) -> str | None:
"""Check if base distilled model already exists locally."""
base_dir = Path(LOCAL_BASE_DIR)
model_dir = base_dir / f"code_model2vec_{teacher_name}"
if model_dir.exists():
# Check for essential model files
has_config = (model_dir / "config.json").exists()
has_model_file = any(
[
(model_dir / "model.safetensors").exists(),
(model_dir / "model.bin").exists(),
(model_dir / "pytorch_model.bin").exists(),
]
)
if has_config and has_model_file:
logger.info(f"β
Found existing base model: {teacher_name}")
return str(model_dir)
return None
def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None:
"""Check if final model already exists locally."""
final_dir = Path(LOCAL_FINAL_DIR)
# Add suffix for trained models
model_name = f"code_model2vec_{teacher_name}"
if enable_training:
model_name += "_fine_tuned"
final_path = final_dir / model_name
if final_path.exists():
# Check for essential model files
has_config = (final_path / "config.json").exists()
has_model_file = any(
[
(final_path / "model.safetensors").exists(),
(final_path / "model.bin").exists(),
(final_path / "pytorch_model.bin").exists(),
]
)
if has_config and has_model_file:
logger.info(f"β
Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
return str(final_path)
return None
def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool:
"""Copy base model to final directory."""
import shutil
base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
# Add suffix for trained models
final_model_name = f"code_model2vec_{teacher_name}"
if enable_training:
final_model_name += "_fine_tuned"
final_path = Path(LOCAL_FINAL_DIR) / final_model_name
try:
final_path.parent.mkdir(parents=True, exist_ok=True)
if final_path.exists():
shutil.rmtree(final_path)
shutil.copytree(base_path, final_path)
logger.info(f"π Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}")
return True
except Exception:
logger.exception(f"β Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}")
return False
def sync_model_from_beam(
teacher_name: str,
target_dir: str,
use_beam_utilities: bool = False,
) -> bool:
"""Sync model from Beam volume to local directory."""
if not use_beam_utilities:
return False
try:
target_path = Path(target_dir)
target_path.mkdir(parents=True, exist_ok=True)
beam_model_name = f"{teacher_name}_model"
success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path))
if success:
logger.info(f"π₯ Synced {teacher_name} from Beam to {target_dir}")
return True
logger.warning(f"β οΈ Failed to sync {teacher_name} from Beam")
return False
except Exception as e:
logger.warning(f"Failed to sync {teacher_name} from Beam: {e}")
return False
def sync_model_to_beam(
teacher_name: str,
source_dir: str,
use_beam_utilities: bool = False,
) -> bool:
"""Sync model from local directory to Beam volume."""
if not use_beam_utilities:
return False
try:
beam_model_name = f"{teacher_name}_model"
success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir)
if success:
logger.info(f"π€ Synced {teacher_name} to Beam from {source_dir}")
return True
logger.warning(f"β οΈ Failed to sync {teacher_name} to Beam")
return False
except Exception as e:
logger.warning(f"Failed to sync {teacher_name} to Beam: {e}")
return False
# =============================================================================
# DISTILLATION FUNCTIONS
# =============================================================================
def simple_distillation(
teacher_model: str,
output_dir: str,
pca_dims: int | None = None,
retry_with_cache_clear: bool = False,
) -> Any:
"""
Perform simple Model2Vec distillation without additional training.
Args:
teacher_model: Name of teacher model
output_dir: Output directory for the distilled model
pca_dims: PCA dimensions (uses config default if None)
retry_with_cache_clear: Whether this is a retry after clearing cache
Returns:
Distilled model or None if failed
"""
if pca_dims is None:
pca_dims = int(distillation_config.optimal_pca_dims)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else ""
logger.info(f"π Simple distillation{retry_suffix}: {teacher_model} β {output_dir}")
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
start_time = time.time()
try:
# Perform distillation with optimal parameters
model = distill(
model_name=teacher_model,
pca_dims=int(pca_dims),
apply_zipf=bool(distillation_config.apply_zipf),
sif_coefficient=float(distillation_config.sif_coefficient),
trust_remote_code=True,
)
logger.info("β
Core distillation completed successfully")
# Validate model before saving
if hasattr(model, "tokenizer") and hasattr(model, "embedding"):
vocab_size = len(model.tokenizer.get_vocab())
embedding_size = model.embedding.shape[0]
logger.info("π Model validation:")
logger.info(f" - Vocabulary size: {vocab_size}")
logger.info(f" - Embedding matrix size: {embedding_size}")
if vocab_size != embedding_size:
logger.warning(f"β οΈ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}")
logger.warning("β οΈ This may cause issues in downstream usage")
else:
logger.info("β
Vocabulary and embedding sizes match")
# Save the model
model.save_pretrained(str(output_path))
logger.info(f"πΎ Model saved to {output_path}")
# Log model info
logger.info(f"Model type: {type(model)}")
if hasattr(model, "embedding"):
logger.info(f"Embedding shape: {model.embedding.shape}")
logger.info(f"Embedding dtype: {model.embedding.dtype}")
total_time = time.time() - start_time
logger.info(f"π Simple distillation completed in {total_time:.2f} seconds")
return model
except ValueError as e:
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
logger.warning(f"β οΈ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
logger.warning(f"Error details: {e}")
logger.warning("π‘ This model has incompatible tokenization. Skipping...")
return None
if "weight is on the meta device" in str(e):
logger.warning(f"β οΈ Device placement issue with {teacher_model} - model weights on meta device")
logger.warning(f"Error details: {e}")
logger.warning("π‘ This model has device placement issues. Skipping...")
return None
raise
except AttributeError as e:
if "backend_tokenizer" in str(e):
logger.warning(f"β οΈ Tokenizer compatibility issue with {teacher_model}")
logger.warning(f"Error details: {e}")
logger.warning("π‘ This model's tokenizer is incompatible with Model2Vec. Skipping...")
return None
raise
except FileNotFoundError as e:
if "transformers_modules" in str(e) or "xlm_padding.py" in str(e):
logger.warning(f"β οΈ Missing custom model files for {teacher_model}")
logger.warning(f"Error details: {e}")
# Try clearing cache and retrying once
if not retry_with_cache_clear:
logger.info("π§ Attempting to clear cache and retry...")
if clear_model_cache(teacher_model):
logger.info("π Retrying distillation after cache clear...")
return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True)
logger.warning("π‘ This model has missing dependencies. Manual intervention may be required.")
return None
raise
except Exception:
logger.exception(f"β Simple distillation failed for {teacher_model}")
return None
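# Example call for the function above (teacher model and output path are illustrative; the
# directory name follows the code_model2vec_{teacher_name} convention used elsewhere in this file):
#
#   model = simple_distillation(
#       teacher_model="BAAI/bge-base-en-v1.5",
#       output_dir="code_model2vec/base/code_model2vec_bge_base_en_v1.5",
#   )
#   if model is not None:
#       print(model.embedding.shape)  # first dimension matches the tokenizer vocabulary size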
def load_optimized_dataset(
max_samples: int | None = None,
checkpoint_manager: BeamCheckpointManager | None = None,
dataset_path: str | None = None,
) -> list[str]:
"""Load our pre-created optimized dataset for tokenlearn training."""
from .dataset import DATASET_OUTPUT_DIR
from .dataset import load_optimized_dataset as load_dataset_func
# Use configuration if not provided as parameter
if dataset_path is None:
dataset_path = distillation_config.custom_dataset_path
dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
# Use configuration default if not specified
if max_samples is None:
max_samples = distillation_config.tokenlearn_max_samples
logger.info(f"π― Loading optimized dataset from {dataset_dir}")
logger.info(f"π Target samples: {max_samples}")
try:
# Load the training split of our optimized dataset
df = load_dataset_func(output_dir=dataset_dir, split="train")
# Extract the text column (which contains our formatted query + code)
texts = df["text"].tolist()
# Shuffle for better training distribution
import random
random.seed(42)
random.shuffle(texts)
# Limit to max_samples
if len(texts) > max_samples:
texts = texts[:max_samples]
logger.info(f"β
Loaded {len(texts)} optimized training samples")
# Log language distribution
languages = df["language"].value_counts()
logger.info("π Language distribution:")
for lang, count in languages.items():
percentage = (count / len(df)) * 100
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
return texts
except Exception as e:
logger.warning(f"β οΈ Failed to load optimized dataset: {e}")
logger.info("π Falling back to original CodeSearchNet loading...")
return load_codesearchnet_dataset(max_samples, checkpoint_manager)
def load_codesearchnet_dataset(
max_samples: int | None = None,
checkpoint_manager: BeamCheckpointManager | None = None,
) -> list[str]:
"""Load and format the CodeSearchNet dataset for token frequency computation."""
from datasets import load_dataset
# Use configuration default if not specified
if max_samples is None:
max_samples = distillation_config.tokenlearn_max_samples
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
logger.info(f"Limiting to {max_samples} samples for training efficiency")
logger.info(f"Languages: {', '.join(languages_config.all)}")
# Check for existing dataset checkpoint
texts = []
start_from = 0
if checkpoint_manager:
checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
if checkpoint_data:
cached_texts = checkpoint_data.get("data", {}).get("texts", [])
if len(cached_texts) >= max_samples:
logger.info(f"β
Resumed dataset loading: {len(cached_texts)} texts from checkpoint")
return cached_texts[:max_samples]
logger.info(f"π Partial dataset found: {len(cached_texts)} texts, continuing...")
texts = cached_texts
start_from = len(texts)
try:
# Calculate samples per language for balanced distribution
num_languages = len(languages_config.all)
samples_per_language = max_samples // num_languages
remaining_samples = max_samples % num_languages
logger.info(f"π Target distribution: {samples_per_language} samples per language")
if remaining_samples > 0:
logger.info(f"π Extra {remaining_samples} samples will be distributed to first languages")
# Load training data from each language separately for balanced distribution
language_texts: dict[str, list[str]] = {}
total_collected = len(texts)
for i, language in enumerate(languages_config.all):
if total_collected >= max_samples:
break
logger.info(f"π Loading {language} training data...")
# Determine how many samples to collect for this language
target_for_lang = samples_per_language
if i < remaining_samples: # Distribute extra samples to first languages
target_for_lang += 1
# Skip if we already have enough from this language
if language in language_texts and len(language_texts[language]) >= target_for_lang:
continue
try:
# Load training split for the specific language (same format as evaluate.py)
from datasets import load_dataset
dataset = load_dataset(
codesearchnet_config.dataset_name,
language,
split="train",
trust_remote_code=True,
)
lang_texts: list[str] = []
processed_count = 0
for processed_count, example in enumerate(dataset, 1):
if len(lang_texts) >= target_for_lang:
break
# Use same field names as evaluate.py
doc_string = example.get("func_documentation_string", "").strip()
code_string = example.get("func_code_string", "").strip()
if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50:
# Format as documentation-code pair for training (same as evaluate.py)
text = f"Documentation: {doc_string}\nCode:\n{code_string}"
# Ensure reasonable length for embedding models
if len(text) <= 2048:
lang_texts.append(text)
if processed_count % 5000 == 0:
logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}")
language_texts[language] = lang_texts
total_collected += len(lang_texts)
logger.info(f"β
{language}: collected {len(lang_texts)} samples")
except Exception as e:
logger.warning(f"β οΈ Failed to load {language} data: {e}")
continue
# Combine all language texts in a balanced way
combined_texts = []
# Add existing texts first (from checkpoint)
if start_from > 0:
combined_texts = texts[:start_from]
# Interleave texts from different languages for better training distribution
max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0
for sample_idx in range(max_lang_samples):
for language in languages_config.all:
if len(combined_texts) >= max_samples:
break
if language in language_texts and sample_idx < len(language_texts[language]):
combined_texts.append(language_texts[language][sample_idx])
if len(combined_texts) >= max_samples:
break
# Truncate to exact max_samples
combined_texts = combined_texts[:max_samples]
# Log final distribution
logger.info("π Final dataset distribution:")
lang_counts: dict[str, int] = {}
for text in combined_texts:
# Simple heuristic to identify language from code patterns
if "def " in text and ":" in text:
lang_counts["python"] = lang_counts.get("python", 0) + 1
elif "function " in text and "{" in text:
lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1
elif "public " in text and "class " in text:
lang_counts["java"] = lang_counts.get("java", 0) + 1
elif "<?php" in text or "$" in text:
lang_counts["php"] = lang_counts.get("php", 0) + 1
elif "func " in text and "end" in text:
lang_counts["ruby"] = lang_counts.get("ruby", 0) + 1
elif "func " in text and "}" in text:
lang_counts["go"] = lang_counts.get("go", 0) + 1
else:
lang_counts["other"] = lang_counts.get("other", 0) + 1
for lang, count in lang_counts.items():
percentage = (count / len(combined_texts)) * 100
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
# Final checkpoint save
if checkpoint_manager:
checkpoint_data = {
"config_hash": get_current_config_hash(enable_training=True),
"stage": "dataset",
"step": 0,
"timestamp": time.time(),
"data": {"texts": combined_texts},
}
checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet")
return combined_texts
except Exception:
logger.exception("Error loading CodeSearchNet dataset")
return texts # Return what we have so far
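# Each training text produced above uses the same documentation/code pairing as evaluate.py;
# a synthetic example of the formatted string looks like:
#
#   Documentation: Return the sum of two integers.
#   Code:
#   def add(a, b):
#       return a + b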
def generate_teacher_embeddings(
teacher_model: SentenceTransformer,
texts: list[str],
checkpoint_manager: BeamCheckpointManager | None = None,
) -> torch.Tensor:
"""Generate teacher embeddings for code training with checkpoint support."""
logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
# Check for existing embeddings checkpoint
if checkpoint_manager:
volume_path = Path(VOLUME_CONFIG.mount_path)
embeddings_path = volume_path / "embeddings_cache.pt"
config_path = volume_path / "embeddings_config.json"
if embeddings_path.exists() and config_path.exists():
try:
# Load config first to validate compatibility
with config_path.open("r") as f:
config_data = json.load(f)
current_hash = get_current_config_hash(enable_training=True)
if config_data.get("config_hash") == current_hash:
# Load the embeddings tensor
final_embeddings = torch.load(embeddings_path, map_location="cpu")
num_expected = config_data.get("num_texts", len(texts))
if final_embeddings.shape[0] >= num_expected:
logger.info(f"β
Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)")
return final_embeddings[: len(texts)]
except Exception as e:
logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
# Generate embeddings from scratch
logger.info("Generating fresh teacher embeddings...")
batch_size = 16 # Fixed batch size for teacher embedding generation
embeddings_list = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]
try:
batch_embeddings = teacher_model.encode(
batch_texts,
convert_to_tensor=True,
batch_size=batch_size,
show_progress_bar=False,
normalize_embeddings=True,
)
embeddings_list.append(batch_embeddings)
if i % (batch_size * 10) == 0:
logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
except torch.cuda.OutOfMemoryError:
logger.warning(f"GPU OOM with batch size {batch_size}, reducing...")
torch.cuda.empty_cache()
batch_size = max(1, batch_size // 2)
# Retry with smaller batch size
batch_embeddings = teacher_model.encode(
batch_texts,
convert_to_tensor=True,
batch_size=batch_size,
show_progress_bar=False,
normalize_embeddings=True,
)
embeddings_list.append(batch_embeddings)
# Combine all embeddings
teacher_embeddings = torch.cat(embeddings_list, dim=0)
# Ensure fp32 precision
if teacher_embeddings.dtype != torch.float32:
teacher_embeddings = teacher_embeddings.to(torch.float32)
logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
# Save embeddings cache for future runs
if checkpoint_manager:
try:
volume_path = Path(VOLUME_CONFIG.mount_path)
embeddings_path = volume_path / "embeddings_cache.pt"
config_path = volume_path / "embeddings_config.json"
# Save embeddings tensor
torch.save(teacher_embeddings, embeddings_path)
# Save configuration
config_data = {
"config_hash": get_current_config_hash(enable_training=True),
"num_texts": len(texts),
"embedding_shape": list(teacher_embeddings.shape),
"timestamp": time.time(),
}
with config_path.open("w") as f:
json.dump(config_data, f, indent=2)
logger.info("πΎ Saved embeddings cache for future runs")
except Exception as e:
logger.warning(f"Failed to save embeddings cache: {e}")
return teacher_embeddings
def tokenlearn_training(
student_model: Any,
teacher_model: SentenceTransformer,
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
) -> Any:
"""
Perform tokenlearn training following the official POTION approach.
This follows the 4-step process:
1. Model2Vec distillation (already done - student_model)
2. Sentence transformer inference (create features)
3. Tokenlearn training
4. Loading the trained model (post-training re-regularization)
"""
from pathlib import Path
logger.info("π§ͺ Starting tokenlearn training (POTION approach)...")
# Create persistent directories for tokenlearn workflow (for checkpoint preservation)
teacher_model_name = getattr(teacher_model, "model_name", None)
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
# Try to extract from the first module if it's a SentenceTransformer
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
teacher_model_name = first_module.auto_model.name_or_path
if not teacher_model_name:
teacher_model_name = "unknown_teacher"
# Use persistent directory for tokenlearn checkpoints
teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_")
persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug
features_dir = persistent_tokenlearn_dir / "features"
model_dir = persistent_tokenlearn_dir / "base_model"
trained_dir = persistent_tokenlearn_dir / "trained_model"
features_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
trained_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"π Using persistent tokenlearn directory: {persistent_tokenlearn_dir}")
# Save the base distilled model for tokenlearn
student_model.save_pretrained(str(model_dir))
logger.info(f"πΎ Saved base model to {model_dir}")
# Step 2: Create features using sentence transformer
logger.info("π Step 2: Creating features using sentence transformer...")
# Get teacher model name/path for tokenlearn
teacher_model_name = getattr(teacher_model, "model_name", None)
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
# Try to extract from the first module if it's a SentenceTransformer
# _modules is a dict-like container, get the first module by iterating
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
teacher_model_name = first_module.auto_model.name_or_path
logger.info(f"π Using teacher model: {teacher_model_name}")
# Prepare dataset for tokenlearn featurization
dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir)
# Check if featurization already completed (checkpoint detection)
featurization_complete_marker = features_dir / ".featurization_complete"
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
logger.info("β
Found existing featurization checkpoint with valid output files")
logger.info(f"π Using cached features from: {features_dir}")
# Verify marker is still valid
output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
logger.info(f"π Found {len(output_files)} cached feature files")
else:
if featurization_complete_marker.exists():
logger.warning("β οΈ Featurization marker exists but output files are missing - re-running featurization")
featurization_complete_marker.unlink()
logger.info("π No valid featurization checkpoint found - starting featurization...")
if not teacher_model_name:
logger.warning("β οΈ Could not determine teacher model name, using fallback")
teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
logger.info(f"π Using teacher model: {teacher_model_name}")
try:
# Use direct function call instead of subprocess
from datasets import load_dataset
from distiller.tokenlearn.featurize import featurize
logger.info("π Running tokenlearn featurization...")
logger.info(f"π Dataset: {dataset_path} (config: {dataset_name})")
logger.info(f"π Text field: {text_key}")
# Load the dataset
if dataset_name is None:
# For local JSON files, don't pass name parameter
dataset = load_dataset(
"json",
data_files=dataset_path,
split="train",
streaming=True,
)
else:
# For remote datasets with specific configurations
dataset = load_dataset(
dataset_path,
name=dataset_name,
split="train",
streaming=True,
)
# Call featurization function directly
featurize(
dataset=iter(dataset),
model=teacher_model,
output_dir=str(features_dir),
max_means=50000, # IMPROVEMENT: Limit means to prevent overfitting
batch_size=512, # IMPROVEMENT: Smaller batch for better gradients
text_key=text_key,
)
logger.info("β
Featurization completed successfully")
# Create checkpoint marker to indicate featurization is complete
featurization_complete_marker.touch()
logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}")
except Exception as e:
logger.exception("π₯ Tokenlearn featurization failed")
logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed")
msg = f"Tokenlearn featurization failed: {e}"
raise RuntimeError(msg) from e
# Step 3: Train using tokenlearn-train
logger.info("π Step 3: Training using tokenlearn...")
# Check if training already completed (checkpoint detection)
training_complete_marker = trained_dir / ".training_complete"
training_fallback_marker = trained_dir / ".training_fallback"
if training_complete_marker.exists() and verify_training_output(trained_dir):
logger.info("β
Found existing training checkpoint with valid model files")
logger.info(f"π Using cached trained model from: {trained_dir}")
# Show available model files
model_files = []
for pattern in ["*.json", "*.safetensors", "*.bin"]:
model_files.extend(list(trained_dir.glob(pattern)))
for subdir in ["model", "model_weighted"]:
subdir_path = trained_dir / subdir
if subdir_path.exists():
model_files.extend(list(subdir_path.glob(pattern)))
logger.info(f"π Found {len(model_files)} cached model files")
elif training_fallback_marker.exists():
logger.warning("β οΈ Training fallback marker found - tokenlearn failed previously")
logger.info("π Proceeding with fallback to base model (simple distillation)")
# Skip training and proceed to model loading (will fallback to base model)
else:
if training_complete_marker.exists():
logger.warning("β οΈ Training marker exists but model files are missing - re-running training")
training_complete_marker.unlink()
logger.info("π No valid training checkpoint found - starting training...")
try:
# Use direct function call instead of subprocess
from distiller.tokenlearn.train import train_model
from distiller.tokenlearn.utils import collect_means_and_texts
# IMPROVED APPROACH: Try optimized parameters first
logger.info("π Attempting IMPROVED tokenlearn training with optimized parameters...")
logger.info("π Using smaller vocabulary and conservative PCA to prevent overfitting")
# Collect training data from features directory
paths = sorted(features_dir.glob("*.json"))
train_txt, train_vec = collect_means_and_texts(paths)
logger.info(f"π Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training")
try:
# Try improved parameters first
trained_model = train_model(
model_name=str(teacher_model_name),
train_txt=train_txt,
train_vec=train_vec,
device="cuda" if torch.cuda.is_available() else "cpu",
vocab_size=25000, # IMPROVEMENT: Smaller vocabulary to prevent overfitting
pca_dims=256, # IMPROVEMENT: Conservative PCA dimensions
)
# Save the trained model
trained_model.save_pretrained(str(trained_dir))
logger.info("β
IMPROVED tokenlearn training completed successfully")
training_complete_marker.touch()
logger.info(f"πΎ Created improved training checkpoint: {training_complete_marker}")
except Exception as e:
logger.warning(f"β οΈ Improved training failed: {e}")
logger.info("π Falling back to CONSERVATIVE tokenlearn training...")
# FALLBACK: Ultra-conservative training approach
try:
trained_model = train_model(
model_name=str(teacher_model_name),
train_txt=train_txt,
train_vec=train_vec,
device="cuda" if torch.cuda.is_available() else "cpu",
vocab_size=15000, # FALLBACK: Even smaller vocabulary
pca_dims=128, # FALLBACK: Smaller PCA dimensions
)
# Save the trained model
trained_model.save_pretrained(str(trained_dir))
logger.info("β
Conservative tokenlearn training completed successfully")
training_complete_marker.touch()
logger.info(f"πΎ Created conservative training checkpoint: {training_complete_marker}")
except Exception as e2:
logger.exception("β Conservative tokenlearn training also failed")
logger.exception("π₯ All training approaches failed - check output above for details")
# Create training marker to indicate we tried but failed
training_fallback_marker = trained_dir / ".training_fallback"
training_fallback_marker.touch()
logger.exception("π₯ Tokenlearn training failed completely")
msg = f"All tokenlearn training approaches failed: {e2}"
raise RuntimeError(msg) from e2
except Exception as e:
logger.warning("π₯ All tokenlearn training approaches failed")
logger.exception("π₯ All training approaches failed completely - cannot proceed")
msg = f"All training approaches failed: {e}"
raise RuntimeError(msg) from e
# Step 4: Load the trained model and apply post-training re-regularization
logger.info("π¦ Step 4: Loading trained model and applying post-training re-regularization...")
# Check if we need to use fallback due to tokenlearn failure
training_fallback_marker = trained_dir / ".training_fallback"
if training_fallback_marker.exists():
logger.error("β Tokenlearn training failed previously - cannot return trained model")
logger.error("π₯ Training was requested but failed - this would be misleading to return base model")
msg = "Tokenlearn training failed - cannot proceed with training pipeline"
raise RuntimeError(msg)
try:
from distiller.model2vec.model import StaticModel
# Load the trained model from tokenlearn
trained_model_path = trained_dir / "model"
if not trained_model_path.exists():
# Try alternative paths
possible_paths = [
trained_dir / "model_weighted",
trained_dir,
]
for path in possible_paths:
if path.exists() and any(path.glob("*.json")):
trained_model_path = path
break
else:
logger.error(f"β Could not find trained model in {trained_dir}")
logger.error("π₯ Training was requested but no trained model found - cannot proceed")
msg = f"Trained model not found in {trained_dir} - training pipeline failed"
raise RuntimeError(msg)
# Load the model before re-regularization
logger.info("π Loading model from tokenlearn training...")
trained_model = StaticModel.from_pretrained(str(trained_model_path))
# Return the trained model directly
logger.info("β
Tokenlearn training pipeline completed successfully")
return trained_model
except ValueError as e:
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
logger.exception("π₯ Token-vector mismatch in tokenlearn training")
logger.exception("Error details")
logger.exception("π§ This is a known issue with tokenlearn/Model2Vec integration")
logger.exception("π₯ Training was requested but failed due to token-vector mismatch")
msg = f"Tokenlearn training failed due to token-vector mismatch: {e}"
raise RuntimeError(msg) from e
logger.exception("π₯ Failed to load tokenlearn trained model")
msg = f"Failed to load tokenlearn trained model: {e}"
raise RuntimeError(msg) from e
except Exception as e:
logger.exception("π₯ Failed to load tokenlearn trained model")
logger.exception("π₯ Cannot load trained model - training failed")
msg = f"Failed to load tokenlearn trained model: {e}"
raise RuntimeError(msg) from e
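# On-disk layout used by tokenlearn_training above (the root derives from directories.base as
# in the code; marker file names are taken from the implementation):
#
#   code_model2vec/tokenlearn_cache/<teacher_slug>/
#       features/       # sentence-transformer features + .featurization_complete marker
#       base_model/     # the distilled Model2Vec student saved before training
#       trained_model/  # tokenlearn output + .training_complete / .training_fallback markers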
def distill_single_teacher(
teacher_model: str,
enable_training: bool = False,
use_beam_utilities: bool = False,
pca_dims: int | None = None,
) -> dict[str, Any]:
"""
Distill a single teacher model with optional training.
Args:
teacher_model: Name of teacher model
enable_training: Whether to enable advanced training
use_beam_utilities: Whether to use Beam utilities
pca_dims: PCA dimensions
Returns:
Dictionary with distillation results
"""
teacher_name = teacher_model.split("/")[-1].replace("-", "_")
base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
# Add suffix for trained models
final_model_name = f"code_model2vec_{teacher_name}"
if enable_training:
final_model_name += "_fine_tuned"
final_dir = Path(LOCAL_FINAL_DIR) / final_model_name
logger.info(f"\n{'=' * 60}")
logger.info(f"π Processing teacher model: {teacher_model}")
logger.info(f"π Teacher name: {teacher_name}")
logger.info(f"π Training enabled: {enable_training}")
logger.info(f"{'=' * 60}")
# Check model compatibility first
is_compatible, warning_msg = check_model_compatibility(teacher_model)
if not is_compatible:
logger.warning(f"β οΈ Known compatibility issue: {warning_msg}")
logger.info("π§ Attempting distillation anyway, but may fail...")
# Try model-specific workarounds
workaround_type = try_model_workarounds(teacher_model)
# Don't skip if we have a workaround - we'll use it later
start_time = time.time()
# Initialize Beam utilities if requested
checkpoint_mgr = None
if use_beam_utilities:
try:
_, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
except Exception as e:
logger.warning(f"Failed to initialize Beam utilities: {e}")
try:
# Step 1: Check for existing final model
existing_final = check_existing_final_model(teacher_name, enable_training)
if existing_final:
logger.info(f"β
Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "skipped_existing_final",
"final_path": existing_final,
"distillation_time": total_time,
}
# Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
if use_beam_utilities and checkpoint_mgr:
logger.info(f"π Syncing existing checkpoints for {teacher_name}...")
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints)
if enable_training:
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
# Step 2: Check for existing base model or create it
existing_base = check_existing_base_model(teacher_name)
base_model = None
if existing_base:
logger.info(f"β
Found existing base model: {teacher_name}")
if enable_training:
# Load base model for training
from distiller.model2vec.model import StaticModel
base_model = StaticModel.from_pretrained(existing_base)
elif use_beam_utilities:
synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities)
if synced:
existing_base = str(base_dir)
if enable_training:
from distiller.model2vec.model import StaticModel
base_model = StaticModel.from_pretrained(existing_base)
if not existing_base:
# Perform simple distillation to create base model
logger.info(f"π Creating base model for {teacher_name}")
# Check if we need specialized distillation
workaround_type = try_model_workarounds(teacher_model)
if workaround_type == "salesforce":
base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims)
elif workaround_type == "baai":
base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims)
else:
base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
if base_model is None:
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "failed_base_distillation",
"error": "Simple distillation failed",
"distillation_time": total_time,
}
# Sync base model and checkpoints to Beam
if use_beam_utilities:
sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities)
if checkpoint_mgr:
sync_checkpoints_to_beam(
VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints
)
existing_base = str(base_dir)
# Step 3: Handle final model creation
if enable_training and base_model is not None:
# Perform tokenlearn training (POTION approach)
logger.info(f"π§ͺ Starting tokenlearn training for {teacher_name}")
try:
# Load teacher model for training
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
# Perform tokenlearn training (POTION approach)
final_model = tokenlearn_training(
base_model,
teacher_st_model,
checkpoint_mgr,
)
# Save final model
final_dir.mkdir(parents=True, exist_ok=True)
final_model.save_pretrained(str(final_dir))
# Sync final model and training checkpoints to Beam
if use_beam_utilities:
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
if checkpoint_mgr:
sync_checkpoints_to_beam(
VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
)
del teacher_st_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
except RuntimeError as e:
# Training failed - clean up and return failure
logger.exception(f"β Training failed for {teacher_name}")
# Clean up teacher model if it was loaded
if "teacher_st_model" in locals():
del teacher_st_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "failed_training",
"error": f"Training failed: {e!s}",
"base_path": existing_base, # Base model was created successfully
"distillation_time": total_time,
}
else:
# Copy base to final (no training)
logger.info(f"π Copying base to final for {teacher_name}")
if not copy_base_to_final(teacher_name, enable_training):
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "failed_copy_to_final",
"error": "Failed to copy base to final",
"distillation_time": total_time,
}
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "success",
"enable_training": enable_training,
"base_path": existing_base,
"final_path": str(final_dir),
"distillation_time": total_time,
}
except Exception as e:
logger.exception(f"β Failed to process {teacher_model}")
total_time = time.time() - start_time
return {
"teacher_model": teacher_model,
"teacher_name": teacher_name,
"status": "failed",
"error": str(e),
"distillation_time": total_time,
}
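# Possible "status" values returned by distill_single_teacher (collected from the branches above):
# "success", "skipped_existing_final", "failed_base_distillation", "failed_training",
# "failed_copy_to_final", and "failed".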
# =============================================================================
# MAIN EXECUTION FUNCTIONS
# =============================================================================
def run_local_distillation(
teacher_models: list[str] | None = None,
enable_training: bool = False,
pca_dims: int | None = None,
clear_cache: bool = False,
) -> dict[str, Any]:
"""Run distillation locally."""
logger.info("π₯οΈ Running distillation locally")
if teacher_models is None:
teacher_models = DEFAULT_TEACHER_MODELS
results = {}
successful_models = []
logger.info("π Starting distillation workflow")
logger.info(f"π Processing {len(teacher_models)} teacher models")
logger.info(f"π Training enabled: {enable_training}")
# Use default models if none specified
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
logger.info(f"π Teacher models to process: {len(models_to_distill)}")
for i, model in enumerate(models_to_distill, 1):
logger.info(f" {i}. {model}")
# Clear cache for problematic models if requested
if clear_cache:
logger.info("π§Ή Clearing cache for known problematic models...")
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
for model in problematic_models:
if model in models_to_distill:
clear_model_cache(model)
# Clear tokenlearn checkpoints if requested (for training mode)
# Note: Checkpoint clearing is handled at the main function level
# Run distillation workflow
for teacher_model in models_to_distill:
result = distill_single_teacher(
teacher_model=teacher_model,
enable_training=enable_training,
use_beam_utilities=False,
pca_dims=pca_dims,
)
teacher_name = result["teacher_name"]
results[teacher_name] = result
if result["status"] == "success" or result["status"].startswith("skipped"):
successful_models.append(teacher_name)
elif result["status"] == "failed_training":
# Note: Training failed but base model may still be available
logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded")
# Summary
logger.info("\nπ DISTILLATION WORKFLOW COMPLETE!")
logger.info(f"π Successful models: {len(successful_models)}")
logger.info(f"π Training mode: {'Enabled' if enable_training else 'Basic distillation only'}")
for model_name in successful_models:
result = results[model_name]
logger.info(f"β
{model_name}: {result['teacher_model']}")
# Save results summary
results_summary = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"enable_training": enable_training,
"successful_models": successful_models,
"all_results": results,
"total_successful": len(successful_models),
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
}
# Save results to file
results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json"
results_file.parent.mkdir(parents=True, exist_ok=True)
with results_file.open("w") as f:
json.dump(results_summary, f, indent=2)
logger.info(f"π Results summary saved to: {results_file}")
return results_summary
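# Illustrative usage sketch (not called anywhere in the workflow): run a basic
# local distillation for a single teacher and read the summary. The teacher model
# name below is an arbitrary example, not a project default.
def _example_run_local_distillation() -> None:
    summary = run_local_distillation(
        teacher_models=["sentence-transformers/all-MiniLM-L6-v2"],
        enable_training=False,  # basic distillation only, no tokenlearn training
    )
    print(f"{summary['total_successful']}/{summary['total_attempted']} teachers distilled")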
def _beam_distill_internal(
teacher_models: list[str] | None = None,
enable_training: bool = False,
pca_dims: int | None = None,
clear_cache: bool = False,
) -> dict[str, Any]:
"""Shared internal implementation for beam distillation."""
if teacher_models is None:
teacher_models = DEFAULT_TEACHER_MODELS
# Clear cache for problematic models if requested
if clear_cache:
logger.info("π§Ή Clearing cache for known problematic models...")
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
for model in problematic_models:
if model in teacher_models:
clear_model_cache(model)
results = {}
successful_models = []
logger.info("π Starting Beam distillation workflow")
logger.info(f"π Processing {len(teacher_models)} teacher models")
logger.info(f"π Training enabled: {enable_training}")
# Use default models if none specified
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
logger.info(f"π Teacher models to process: {len(models_to_distill)}")
for i, model in enumerate(models_to_distill, 1):
logger.info(f" {i}. {model}")
for teacher_model in models_to_distill:
result = distill_single_teacher(
teacher_model=teacher_model,
enable_training=enable_training,
use_beam_utilities=True,
pca_dims=pca_dims,
)
teacher_name = result["teacher_name"]
results[teacher_name] = result
if result["status"] == "success" or result["status"].startswith("skipped"):
successful_models.append(teacher_name)
elif result["status"] == "failed_training":
# Note: Training failed but base model may still be available
logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded")
# Summary
logger.info("\nπ BEAM DISTILLATION WORKFLOW COMPLETE!")
logger.info(f"π Successful models: {len(successful_models)}")
# Save results to Beam volume
volume_path = Path(VOLUME_CONFIG.mount_path)
results_summary = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"enable_training": enable_training,
"successful_models": successful_models,
"all_results": results,
"total_successful": len(successful_models),
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
}
results_file = volume_path / "distillation_results.json"
with results_file.open("w") as f:
json.dump(results_summary, f, indent=2)
logger.info(f"π Beam results saved to: {results_file}")
return results_summary
@function(**get_training_function_kwargs())
def _beam_train_models(
teacher_models: list[str] | None = None,
enable_training: bool = True,
pca_dims: int | None = None,
clear_cache: bool = False,
) -> dict[str, Any]:
"""Beam function for training (distillation + tokenlearn)."""
logger.info("βοΈ Running training on Beam")
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache)
@function(**get_distillation_function_kwargs())
def _beam_distill_models(
teacher_models: list[str] | None = None,
enable_training: bool = False,
pca_dims: int | None = None,
clear_cache: bool = False,
) -> dict[str, Any]:
"""Beam function for basic distillation only."""
logger.info("βοΈ Running distillation on Beam")
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache)
def run_beam_distillation(
teacher_models: list[str] | None = None,
enable_training: bool = False,
pca_dims: int | None = None,
clear_cache: bool = False,
) -> dict[str, Any]:
"""Run distillation on Beam and sync results."""
logger.info("βοΈ Running distillation on Beam with local sync")
try:
# Choose appropriate beam function based on training flag
beam_function = _beam_train_models if enable_training else _beam_distill_models
# Run distillation on Beam
results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache)
# Check if Beam execution was successful
if not results:
logger.error("β Beam execution failed or returned no results")
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"enable_training": enable_training,
"successful_models": [],
"all_results": {},
"total_successful": 0,
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
"error": "Beam execution failed",
}
# Sync models back to local directories
if results.get("successful_models"):
logger.info("π₯ Syncing models from Beam to local directories...")
for teacher_name in results["successful_models"]:
# Sync base model
base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True)
# Sync final model if training was enabled
if enable_training:
final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}"
sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True)
else:
# Copy base to final
copy_base_to_final(teacher_name, enable_training)
logger.info("β
All models synced from Beam")
return results
except Exception as e:
logger.exception("β Beam distillation failed with exception")
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"enable_training": enable_training,
"successful_models": [],
"all_results": {},
"total_successful": 0,
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
"error": str(e),
}
# =============================================================================
# CLI INTERFACE
# =============================================================================
def main(
use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
clear_cache: Annotated[
bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
] = False,
clear_checkpoints: Annotated[
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
] = False,
use_optimized_dataset: Annotated[
bool,
typer.Option(
"--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
),
] = False,
dataset_path: Annotated[
str | None,
typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
] = None,
) -> None:
"""Unified distillation command with optional training."""
logger.info("π Starting unified Model2Vec distillation workflow")
# Set dataset configuration
distillation_config.use_optimized_dataset = use_optimized_dataset
distillation_config.custom_dataset_path = dataset_path
if use_optimized_dataset and train:
dataset_source = dataset_path or "code_model2vec/dataset"
logger.info(f"π― Using optimized dataset from: {dataset_source}")
elif train:
logger.info("π― Using C4 dataset for training (following POTION approach)")
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}")
# Use default models if none specified
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
logger.info(f"π Teacher models to process: {len(models_to_distill)}")
for i, model in enumerate(models_to_distill, 1):
logger.info(f" {i}. {model}")
# Clear cache for problematic models if requested
if clear_cache:
logger.info("π§Ή Clearing cache for known problematic models...")
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
for model in problematic_models:
if model in models_to_distill:
clear_model_cache(model)
# Clear tokenlearn checkpoints if requested (for training mode)
if clear_checkpoints and train:
logger.info("π§Ή Clearing tokenlearn checkpoints to force fresh featurization and training...")
for teacher_model in models_to_distill:
teacher_model.split("/")[-1].replace("-", "_")
# Use the same persistent directory structure as the training function
teacher_slug = teacher_model.replace("/", "_").replace("-", "_")
persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug
features_dir = persistent_tokenlearn_dir / "features"
trained_dir = persistent_tokenlearn_dir / "trained_model"
# Clear persistent tokenlearn checkpoints
if features_dir.exists() or trained_dir.exists():
clear_tokenlearn_checkpoints(features_dir, trained_dir)
logger.info(f"ποΈ Cleared persistent tokenlearn checkpoints for {teacher_model}")
else:
logger.info(f"βΉοΈ No tokenlearn checkpoints found for {teacher_model}")
elif clear_checkpoints and not train:
logger.warning("β οΈ --clear-checkpoints flag is only relevant when training is enabled (--train)")
# Run distillation workflow
if use_beam:
results = run_beam_distillation(
teacher_models=models_to_distill,
enable_training=train,
pca_dims=pca_dims,
clear_cache=clear_cache,
)
else:
results = run_local_distillation(
teacher_models=models_to_distill,
enable_training=train,
pca_dims=pca_dims,
clear_cache=clear_cache,
)
# Handle case where results might be None or invalid
if not results or not isinstance(results, dict):
logger.error("β Distillation workflow failed - no valid results returned")
results = {
"total_successful": 0,
"total_attempted": len(models_to_distill),
"error": "Workflow failed",
}
# Final summary
successful_count = results.get("total_successful", 0)
total_attempted = results.get("total_attempted", 0)
logger.info("\nπ UNIFIED DISTILLATION WORKFLOW COMPLETED!")
logger.info(f"π Successfully processed: {successful_count}/{total_attempted} models")
logger.info(f"π Base models saved to: {LOCAL_BASE_DIR}")
logger.info(f"π Final models saved to: {LOCAL_FINAL_DIR}")
if train:
logger.info("π Advanced training was enabled - models include CodeSearchNet specialization")
else:
logger.info("π Basic distillation only - use --train flag to enable advanced training")
def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]:
"""
Check if a model has known compatibility issues with Model2Vec.
Returns:
Tuple of (is_compatible, warning_message)
"""
known_incompatible = {
"BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute",
"jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies",
"Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors",
}
if teacher_model in known_incompatible:
return False, known_incompatible[teacher_model]
# Check for model families that might have issues
if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower():
return False, "BGE models with Qwen2 tokenizers may have compatibility issues"
if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower():
return False, "Jina embeddings v3 models may have missing dependencies"
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
return False, "Salesforce SFR embedding models may have device placement issues"
return True, None
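# Illustrative sketch only (not used by the workflow): gate a run on the
# compatibility check above instead of failing mid-distillation.
def _example_compatibility_gate(teacher_model: str) -> bool:
    is_compatible, warning = check_model_compatibility(teacher_model)
    if not is_compatible:
        logger.warning(f"Skipping {teacher_model}: {warning}")
    return is_compatible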
def clear_model_cache(model_name: str) -> bool:
"""Clear HuggingFace cache for a specific model."""
try:
import shutil
from pathlib import Path
# Get HuggingFace cache directory
cache_dir = Path.home() / ".cache" / "huggingface"
# Find model-specific cache directories
model_slug = model_name.replace("/", "--")
# Clear transformers cache
transformers_cache = cache_dir / "transformers" / model_slug
if transformers_cache.exists():
shutil.rmtree(transformers_cache)
logger.info(f"ποΈ Cleared transformers cache for {model_name}")
# Clear hub cache
hub_cache = cache_dir / "hub" / f"models--{model_slug}"
if hub_cache.exists():
shutil.rmtree(hub_cache)
logger.info(f"ποΈ Cleared hub cache for {model_name}")
# Clear modules cache
modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0]
if modules_cache.exists():
shutil.rmtree(modules_cache)
logger.info(f"ποΈ Cleared modules cache for {model_name}")
return True
except Exception as e:
logger.warning(f"Failed to clear cache for {model_name}: {e}")
return False
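# Quick illustrative sketch of the HuggingFace hub cache layout assumed above:
# a model id "org/model-name" is cached under hub/models--org--model-name.
# This helper is for illustration only and is not used by the workflow.
def _example_hub_cache_path(model_name: str) -> Path:
    return Path.home() / ".cache" / "huggingface" / "hub" / f"models--{model_name.replace('/', '--')}"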
def try_model_workarounds(teacher_model: str) -> str | None:
"""
Try specific workarounds for problematic models.
Returns:
The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available
"""
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
logger.info("π§ Salesforce SFR model detected - will use specialized distillation")
return "salesforce"
if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()):
logger.info("π§ BAAI BGE model detected - will use specialized distillation")
return "baai"
return None
def salesforce_model_distillation(
teacher_model: str,
output_dir: str,
pca_dims: int | None = None,
) -> Any:
"""Special distillation function for Salesforce SFR models that handles device placement issues."""
if pca_dims is None:
pca_dims = int(distillation_config.optimal_pca_dims)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
logger.info(f"π Salesforce-specific distillation: {teacher_model} β {output_dir}")
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
start_time = time.time()
try:
import torch
from transformers import AutoModel, AutoTokenizer
# Enhanced custom model loading for Salesforce models
logger.info("π§ Loading model with enhanced device settings...")
# Method 1: Try with to_empty() for meta tensor handling
try:
logger.info("π Attempting with to_empty() method...")
# Load tokenizer first
tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
# Load model with meta device initially
model = AutoModel.from_pretrained(
teacher_model,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map="meta", # Load on meta device first
)
# Move from meta to actual device using to_empty()
if torch.cuda.is_available():
device = torch.device("cuda")
# Create empty tensors on target device and copy weights
model = model.to_empty(device=device)
else:
device = torch.device("cpu")
model = model.to_empty(device=device)
# Ensure model is in the right dtype
model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32)
logger.info("β
Successfully loaded with to_empty() method")
except Exception as e:
logger.warning(f"to_empty() method failed: {e}")
# Method 2: Try SentenceTransformer with specific settings
logger.info("π Falling back to SentenceTransformer method...")
sentence_model = load_model_with_flash_attention(
teacher_model,
device="cpu", # Force CPU loading first
)
# Move to GPU if available
if torch.cuda.is_available():
sentence_model = sentence_model.to("cuda")
# Extract components
model = sentence_model[0].auto_model
tokenizer = sentence_model.tokenizer
logger.info("β
Successfully loaded with SentenceTransformer method")
# Now use Model2Vec's distill_from_model function directly
from distiller.model2vec.distill.distillation import distill_from_model
distilled_model = distill_from_model(
model=model,
tokenizer=tokenizer,
pca_dims=int(pca_dims),
apply_zipf=bool(distillation_config.apply_zipf),
sif_coefficient=float(distillation_config.sif_coefficient),
)
logger.info("β
Core distillation completed successfully")
# Save the model
distilled_model.save_pretrained(str(output_path))
logger.info(f"πΎ Model saved to {output_path}")
# Log model info
logger.info(f"Model type: {type(distilled_model)}")
if hasattr(distilled_model, "embedding"):
logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
total_time = time.time() - start_time
logger.info(f"π Salesforce distillation completed in {total_time:.2f} seconds")
# Clean up
if "sentence_model" in locals():
del sentence_model
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return distilled_model
except Exception:
logger.exception(f"β Salesforce-specific distillation failed for {teacher_model}")
return None
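# The to_empty() fallback above relies on PyTorch's meta-device pattern. A minimal,
# self-contained sketch of that mechanism with plain torch modules (illustrative
# only; assumes PyTorch 2.x for the device context manager): parameters are first
# created without storage on the meta device, materialized with to_empty(), and
# only then filled from a real state dict.
def _example_meta_device_materialization() -> None:
    """Illustrative sketch only: meta device -> to_empty() -> load_state_dict."""
    import torch

    # Build the module skeleton on the meta device: no weight memory is allocated yet.
    with torch.device("meta"):
        skeleton = torch.nn.Linear(16, 4)
    # Materialize (uninitialized) storage on the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    skeleton = skeleton.to_empty(device=device)
    # Only now copy real weights in (here taken from a freshly initialized module).
    reference = torch.nn.Linear(16, 4)
    skeleton.load_state_dict(reference.state_dict())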
def baai_bge_model_distillation(
teacher_model: str,
output_dir: str,
pca_dims: int | None = None,
) -> Any:
"""Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues."""
if pca_dims is None:
pca_dims = int(distillation_config.optimal_pca_dims)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
logger.info(f"π BAAI BGE-specific distillation: {teacher_model} β {output_dir}")
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
start_time = time.time()
try:
import torch
from transformers import AutoModel, AutoTokenizer
logger.info("π§ Loading BAAI model with tokenizer workaround...")
# Try multiple approaches for BAAI models
success = False
# Method 1: Try SentenceTransformer first (often handles tokenizer issues better)
try:
logger.info("π Attempting with SentenceTransformer wrapper...")
sentence_model = load_model_with_flash_attention(teacher_model)
# Extract components
model = sentence_model[0].auto_model
tokenizer = sentence_model.tokenizer
# Test if tokenizer works by encoding a simple text
test_encoding = tokenizer.encode("test", return_tensors="pt")
logger.info("β
SentenceTransformer method successful")
success = True
except Exception as e:
logger.warning(f"SentenceTransformer method failed: {e}")
# Method 2: Try direct loading with tokenizer replacement
try:
logger.info("π Attempting with tokenizer replacement...")
from transformers import BertTokenizerFast
# Load model directly
model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True)
# Try to use a compatible tokenizer instead
try:
# First try the original tokenizer
tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
except Exception:
# Fall back to BERT tokenizer for BGE models
logger.info("π Falling back to BERT tokenizer...")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
logger.info("β
Tokenizer replacement method successful")
success = True
except Exception as e2:
logger.warning(f"Tokenizer replacement method failed: {e2}")
if not success:
logger.error("β All BAAI model loading methods failed")
return None
# Now use Model2Vec's distill_from_model function directly
from distiller.model2vec.distill.distillation import distill_from_model
distilled_model = distill_from_model(
model=model,
tokenizer=tokenizer,
pca_dims=int(pca_dims),
apply_zipf=bool(distillation_config.apply_zipf),
sif_coefficient=float(distillation_config.sif_coefficient),
)
logger.info("β
Core distillation completed successfully")
# Save the model
distilled_model.save_pretrained(str(output_path))
logger.info(f"πΎ Model saved to {output_path}")
# Log model info
logger.info(f"Model type: {type(distilled_model)}")
if hasattr(distilled_model, "embedding"):
logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")
total_time = time.time() - start_time
logger.info(f"π BAAI BGE distillation completed in {total_time:.2f} seconds")
# Clean up
if "sentence_model" in locals():
del sentence_model
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return distilled_model
except Exception:
logger.exception(f"β BAAI BGE-specific distillation failed for {teacher_model}")
return None
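# The Qwen2Tokenizer issue listed in check_model_compatibility() comes down to slow
# tokenizers not exposing a `backend_tokenizer`. A minimal sketch of detecting that
# up front (illustrative only; not used by the workflow):
def _example_has_fast_tokenizer_backend(model_name: str) -> bool:
    """Illustrative sketch only: True if the tokenizer exposes backend_tokenizer,
    i.e. it is a fast (`tokenizers`-backed) tokenizer."""
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return hasattr(tok, "backend_tokenizer")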
def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None:
"""Clear tokenlearn checkpoint markers to force re-execution of steps."""
featurization_marker = features_dir / ".featurization_complete"
training_marker = trained_dir / ".training_complete"
if featurization_marker.exists():
featurization_marker.unlink()
logger.info(f"ποΈ Cleared featurization checkpoint: {featurization_marker}")
if training_marker.exists():
training_marker.unlink()
logger.info(f"ποΈ Cleared training checkpoint: {training_marker}")
def verify_featurization_output(features_dir: Path) -> bool:
"""Verify that featurization output files actually exist and are valid."""
if not features_dir.exists():
return False
# Check for expected tokenlearn output files
# Check if any expected files exist
return any(list(features_dir.glob(file_pattern)) for file_pattern in ["*.npy", "*.json", "*.pt", "*.pkl"])
def verify_training_output(trained_dir: Path) -> bool:
"""Verify that training output files actually exist and are valid."""
if not trained_dir.exists():
return False
# Check for model files
model_files = ["config.json", "model.safetensors", "modules.json", "tokenizer.json"]
for model_file in model_files:
if (trained_dir / model_file).exists():
return True
# Check for alternative model directory structure
for subdir in ["model", "model_weighted"]:
subdir_path = trained_dir / subdir
if subdir_path.exists():
for model_file in model_files:
if (subdir_path / model_file).exists():
return True
return False
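# clear_tokenlearn_checkpoints() and the two verifiers above cooperate around simple
# marker files. A minimal sketch of the write side of that protocol (illustrative
# only; the marker names mirror the ones cleared above):
def _example_mark_tokenlearn_steps_complete(features_dir: Path, trained_dir: Path) -> None:
    """Illustrative sketch only: create markers once real output is verified."""
    if verify_featurization_output(features_dir):
        (features_dir / ".featurization_complete").touch()
    if verify_training_output(trained_dir):
        (trained_dir / ".training_complete").touch()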
def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
"""
Prepare dataset for tokenlearn featurization.
Returns:
Tuple of (dataset_path, dataset_name, text_key) for tokenlearn
"""
if distillation_config.use_optimized_dataset:
return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir)
return _prepare_original_dataset_for_tokenlearn()
def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
"""Prepare custom optimized dataset for tokenlearn featurization."""
logger.info("π― Preparing custom optimized dataset for tokenlearn...")
# Import the dataset module
from .dataset import create_optimized_dataset, load_optimized_dataset
# Define paths
custom_dataset_dir = (
Path(distillation_config.custom_dataset_path)
if distillation_config.custom_dataset_path
else Path("code_model2vec/dataset")
)
tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset"
# Check if we need to create the custom dataset
if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
logger.info("π Custom dataset not found - creating optimized dataset...")
create_optimized_dataset(
max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, # Divide by number of languages
output_dir=custom_dataset_dir,
create_multiple_formats=False, # Use simple format for tokenlearn
)
# Load the custom dataset
logger.info(f"π Loading custom dataset from {custom_dataset_dir}")
train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train")
# Prepare dataset for tokenlearn (save as JSON files that load_dataset can read)
tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True)
# Save as JSON file that tokenlearn can load with load_dataset()
train_json_path = tokenlearn_dataset_dir / "train.json"
# Create JSON lines format
import json
with train_json_path.open("w") as f:
for text in train_df["text"]:
json.dump({"text": text}, f)
f.write("\n")
logger.info(f"β
Prepared custom dataset with {len(train_df)} samples for tokenlearn")
logger.info(f"πΎ Saved JSON dataset to {train_json_path}")
# Return the JSON file path directly (not directory) and no config name for JSON loading
return str(train_json_path), None, "text"
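# The train.json written above is plain JSON Lines: one {"text": ...} object per line.
# A minimal sketch of reading it back with datasets.load_dataset, the same style of
# loading tokenlearn is expected to perform (illustrative only):
def _example_load_tokenlearn_jsonl(train_json_path: Path) -> None:
    """Illustrative sketch only: load the JSON Lines dataset produced above."""
    from datasets import load_dataset

    ds = load_dataset("json", data_files=str(train_json_path), split="train")
    print(f"{len(ds)} samples; first text starts with: {ds[0]['text'][:80]!r}")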
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]:
"""Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach)."""
logger.info("π Using C4 dataset for tokenlearn (following POTION approach)...")
return (
str(distillation_config.tokenlearn_dataset), # "allenai/c4"
str(distillation_config.tokenlearn_dataset_name), # "en"
str(distillation_config.tokenlearn_text_key), # "text"
)
if __name__ == "__main__":
typer.run(main)