group-telegram.com/neural_cell/264
Last Update:
Разбираемся в генеративных моделях: Flow matching
Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?
Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.
В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!
Как это работает:
1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки
И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:
def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)
true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x
А генерация ещё проще - просто идём маленькими шажками в нужном направлении:
def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x
А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!
Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.
В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.
Главные преимущества по сравнению с DDPM:
- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут
Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.
Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.
Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264
BY the last neural cell
Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260
Share with your friend now:
group-telegram.com/neural_cell/264