Telegram Group & Telegram Channel
Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264



group-telegram.com/neural_cell/264
Create:
Last Update:

Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264

BY the last neural cell


Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260

Share with your friend now:
group-telegram.com/neural_cell/264

View MORE
Open in Telegram


Telegram | DID YOU KNOW?

Date: |

Recently, Durav wrote on his Telegram channel that users' right to privacy, in light of the war in Ukraine, is "sacred, now more than ever." The Securities and Exchange Board of India (Sebi) had carried out a similar exercise in 2017 in a matter related to circulation of messages through WhatsApp. Lastly, the web previews of t.me links have been given a new look, adding chat backgrounds and design elements from the fully-features Telegram Web client. One thing that Telegram now offers to all users is the ability to “disappear” messages or set remote deletion deadlines. That enables users to have much more control over how long people can access what you’re sending them. Given that Russian law enforcement officials are reportedly (via Insider) stopping people in the street and demanding to read their text messages, this could be vital to protect individuals from reprisals. Two days after Russia invaded Ukraine, an account on the Telegram messaging platform posing as President Volodymyr Zelenskiy urged his armed forces to surrender.
from hk


Telegram the last neural cell
FROM American