Telegram Group & Telegram Channel
Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264



group-telegram.com/neural_cell/264
Create:
Last Update:

Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264

BY the last neural cell


Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260

Share with your friend now:
group-telegram.com/neural_cell/264

View MORE
Open in Telegram


Telegram | DID YOU KNOW?

Date: |

The message was not authentic, with the real Zelenskiy soon denying the claim on his official Telegram channel, but the incident highlighted a major problem: disinformation quickly spreads unchecked on the encrypted app. Ukrainian forces have since put up a strong resistance to the Russian troops amid the war that has left hundreds of Ukrainian civilians, including children, dead, according to the United Nations. Ukrainian and international officials have accused Russia of targeting civilian populations with shelling and bombardments. The regulator said it has been undertaking several campaigns to educate the investors to be vigilant while taking investment decisions based on stock tips. Ukrainian President Volodymyr Zelensky said in a video message on Tuesday that Ukrainian forces "destroy the invaders wherever we can." In addition, Telegram now supports the use of third-party streaming tools like OBS Studio and XSplit to broadcast live video, allowing users to add overlays and multi-screen layouts for a more professional look.
from cn


Telegram the last neural cell
FROM American