Niklas Muennighoff, Zitong Yang, Weijia Shi, Xiang Lisa Li, Li Fei-Fei, Hannaneh Hajishirzi, Luke Zettlemoyer, Percy Liang, Emmanuel Candès, Tatsunori Hashimoto
Статья: https://arxiv.org/abs/2501.19393
Репа: https://github.com/simplescaling/s1
Продолжая тему про Test-time compute (https://t.me/gonzo_ML/3175). SFT на маленьком датасете в 1k примеров + простые стратегии чтобы заставить модель подумать подольше — и вуаля, бьём o1-preview на MATH and AIME24!
DeepSeek-R1 (https://t.me/gonzo_ML/3319) безусловно был достижением в обучении модели с ризонингом через RL, и в этом смысле открытый мир приблизился или даже догнал достижения OpenAI в лице o1, но эта работа не целилась в демонстрацию test-time scaling behaviour. В текущей работе совмещают и сильный ризонинг (но без RL), и скейлинг.
Рецепт простой и дешёвый.
1) Собираем датасет s1K из 1000 тщательно отобранных пар вопросов и ответов с reasoning traces, дистиллированных из Gemini Thinking Experimental.
2) Делаем на этом датасете SFT (PyTorch FSDP, 5 эпох, BF16) открытой модели (Qwen2.5- 32B-Instruct), в течение 26 минут на 16 H100 GPUs (это может стоить примерно от $25). Получаем модель s1-32B.
3) После обучения контролируем количество test-time compute через прямолинейную стратегию budget forcing, которая заключается в том, что 1) если модель нагенерировала уже слишком много токенов, то принудительно заканчиваем процесс добавлением end-of-thinking токена, а 2) если модель нагенерировала мало, то наоборот не даём её добавить этот токен окончания размышлений, а добавляем “Wait” и стимулируем модель подумать ещё.
4) Профит! График демонстрирует test-time scaling. Да ещё и получаем модель на Парето-фронте по sample efficiency, лучше o1-preview.
Модель полностью открыта: веса, данные и код.
Теперь чуть подробнее про отдельные шаги.
📔Датасет
Датасет собирался в два этапа. Сначала собрали 59K (59,029) вопросов из 16 источников.
Взяли задачи из имеющихся датасетов (NuminaMATH, AIME, OlympicArena, OmniMath, AGIEval) и создали пару своих, s1-prob с вопросами по теории вероятности (182 штуки), и s1-teasers с головоломками (23 штуки). Вообще не очень привычно слышать про датасеты из 23 примеров…
Для каждого вопроса сгенерировали reasoning trace и решение через Google Gemini Flash Thinking API, получили 59K триплетов <вопрос, размышление, решение>. Почистили этот набор против имеющихся вопросов для evaluation через 8-граммы и дедупликацию.
Можно было сразу на этих 59k и обучать, но хотели найти минимальную конфигурацию, поэтому устроили три стадии фильтрации.
1) Качество (Quality) должно быть высоким. Смотрим на сэмплы и убираем проблемные примеры, например, с плохим форматированием. Уменьшили датасет до 51,581 примеров и из них отобрали 384 в финальный 1k датасет.
2) Сложность (Difficulty) должна быть на уровне и требовать ризонинга. Выкинули слишком лёгкие примеры, которые решил Qwen2.5-7B-Instruct или Qwen2.5-32B-Instruct. Корректность оценивал Claude 3.5 Sonnet относительно референса. Сложность оценивали по длине цепочки рассуждений, подразумевая, что более сложный вопрос требует более длинной цепочки. Сократили датасет до 24,496 примеров.
3) Разнообразие (Diversity): датасеты должны быть из разных областей и с разными задачами. Sonnet классифицировал задачи по Mathematics Subject Classification (MSC), далее выбирали случайно один из доменов, из него сэмплили задачу так, чтобы предпочитались более длинные рассуждения, и повторяли пока не наберётся 1000 примеров. Итого набрали 50 разных доменов.
⚒️Test-time scaling
Авторы разделяют методы скейлинга на последовательные (результат зависит от предыдущих размышлений) и параллельные (типа голосования большинства). Сфокусировались на последовательных, потому что по мнению авторов они должны скейлиться лучше.