From 2fda3d4e78eec14913aa71f5d009aba9bb246fba Mon Sep 17 00:00:00 2001 From: Kevin Black Date: Fri, 23 Jun 2023 19:25:54 -0700 Subject: [PATCH] Initial commit --- .gitignore | 305 +++++ README.md | 1 + config/__pycache__/base.cpython-310.pyc | Bin 0 -> 1243 bytes config/base.py | 56 + .../__pycache__/prompts.cpython-310.pyc | Bin 0 -> 1865 bytes .../__pycache__/rewards.cpython-310.pyc | Bin 0 -> 1549 bytes .../__pycache__/stat_tracking.cpython-310.pyc | Bin 0 -> 1452 bytes ddpo_pytorch/assets/imagenet_classes.txt | 1000 +++++++++++++++++ .../ddim_with_logprob.cpython-310.pyc | Bin 0 -> 4107 bytes .../pipeline_with_logprob.cpython-310.pyc | Bin 0 -> 8940 bytes .../diffusers_patch/ddim_with_logprob.py | 143 +++ .../diffusers_patch/pipeline_with_logprob.py | 225 ++++ ddpo_pytorch/prompts.py | 54 + ddpo_pytorch/rewards.py | 29 + ddpo_pytorch/stat_tracking.py | 34 + scripts/train.py | 341 ++++++ setup.py | 10 + 17 files changed, 2198 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 config/__pycache__/base.cpython-310.pyc create mode 100644 config/base.py create mode 100644 ddpo_pytorch/__pycache__/prompts.cpython-310.pyc create mode 100644 ddpo_pytorch/__pycache__/rewards.cpython-310.pyc create mode 100644 ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc create mode 100644 ddpo_pytorch/assets/imagenet_classes.txt create mode 100644 ddpo_pytorch/diffusers_patch/__pycache__/ddim_with_logprob.cpython-310.pyc create mode 100644 ddpo_pytorch/diffusers_patch/__pycache__/pipeline_with_logprob.cpython-310.pyc create mode 100644 ddpo_pytorch/diffusers_patch/ddim_with_logprob.py create mode 100644 ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py create mode 100644 ddpo_pytorch/prompts.py create mode 100644 ddpo_pytorch/rewards.py create mode 100644 ddpo_pytorch/stat_tracking.py create mode 100644 scripts/train.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c4f2c1c --- /dev/null +++ b/.gitignore @@ -0,0 +1,305 @@ +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim + +### Intellij+all ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Intellij+all Patch ### +# Ignore everything but code style settings and run configurations +# that are supposed to be shared within teams. + +.idea/* + +!.idea/codeStyles +!.idea/runConfigurations + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim + diff --git a/README.md b/README.md new file mode 100644 index 0000000..9040f06 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# ddpo-pytorch \ No newline at end of file diff --git a/config/__pycache__/base.cpython-310.pyc b/config/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b8f958d8dcbd70508315dceab286acf5866763a GIT binary patch literal 1243 zcmZWo%WoVt9QHhRJ3Gm4Hti-&AC$ImAfV8ur4d4?TPZv)s6D``7t1^LOx$`NvS*TR zx{-RQ#07EW3`Z{fCBPluC|p(`!3~KMpJyAi2#@CTH~#%T`}gpa{eDa^$_HQZUps{S z;y#v#kB@tJ%?Q8=_Y`?ZxX**<Yn7mvX-J826*`dr?CJIqg zR)8vFM~^Cz(WJ~PEsT*9sboE!KXdvaVEc7D$QA5LiKIq^#>V?)k;?1?nbbBkLU7wf zhFoeJ7fnt@RVD|-6?58)LcFND|^sUUH#qq{ZRFvC~>? 
zJYjWmK#e>Ww%-^*-^tD2&Y$z#;Uu1uOVg|V$=FP8=S*%bA5XuC<< zOTCF!>kAIv0DMdIJDC0+EUY&s%!s$Esf{ySoW+{^-OVR}$nOSSZ}ALTw!qW1|Ih!i sJ4F!tkr#QPulKOwwy)PPwY&{u3GTpPqxW%R8YS+!AH=w89g$W4AI-3D^#A|> literal 0 HcmV?d00001 diff --git a/config/base.py b/config/base.py new file mode 100644 index 0000000..e38e6bb --- /dev/null +++ b/config/base.py @@ -0,0 +1,56 @@ +import ml_collections + +def get_config(): + + config = ml_collections.ConfigDict() + + # misc + config.seed = 42 + config.logdir = "logs" + config.num_epochs = 100 + config.mixed_precision = "fp16" + config.allow_tf32 = True + + # pretrained model initialization + config.pretrained = pretrained = ml_collections.ConfigDict() + pretrained.model = "runwayml/stable-diffusion-v1-5" + pretrained.revision = "main" + + # training + config.train = train = ml_collections.ConfigDict() + train.mixed_precision = "fp16" + train.batch_size = 1 + train.use_8bit_adam = False + train.scale_lr = False + train.learning_rate = 1e-4 + train.adam_beta1 = 0.9 + train.adam_beta2 = 0.999 + train.adam_weight_decay = 1e-2 + train.adam_epsilon = 1e-8 + train.gradient_accumulation_steps = 1 + train.max_grad_norm = 1.0 + train.num_inner_epochs = 1 + train.cfg = True + train.adv_clip_max = 10 + train.clip_range = 1e-4 + + # sampling + config.sample = sample = ml_collections.ConfigDict() + sample.num_steps = 5 + sample.eta = 1.0 + sample.guidance_scale = 5.0 + sample.batch_size = 1 + sample.num_batches_per_epoch = 4 + + # prompting + config.prompt_fn = "imagenet_animals" + config.prompt_fn_kwargs = {} + + # rewards + config.reward_fn = "jpeg_compressibility" + + config.per_prompt_stat_tracking = ml_collections.ConfigDict() + config.per_prompt_stat_tracking.buffer_size = 128 + config.per_prompt_stat_tracking.min_count = 16 + + return config \ No newline at end of file diff --git a/ddpo_pytorch/__pycache__/prompts.cpython-310.pyc b/ddpo_pytorch/__pycache__/prompts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..440ec872dafeca05f3feb06a0cc0a4be3d105f8c GIT binary patch literal 1865 zcmaJ>%WoS+7@wJac)f`eCk>CZ6f9IBd>|@DjZcV6nluGuo{|gfLxpJZB9*{r+zHc1INh>jTLgi_R5!4l@Amv~P4i7)&~FeOPS=%mKIDNQ259@V+egEMkU zlbDCRc1DsV9`QQJWghb-kPW`f8z5Ks3SR}e%A0%*WRtT4(q4Z7=%h_;T?s9RDieBt zbdB?Y>6Mf?te8v7H5Qn+8xBnUl1P)dJWCK{2ROG90OyOEGB|oyq;J|aV zJ$dfayc7zf>)2nnw|lZLwhxOm`(_&*`{>*;H<)|5Jy5bgFnVinV&k16 z*CvDUZUeE=+YK^$w~#6C;O%;Ar;u4%=v`3eTid=}x^r-F@4-Rmi#reQ+vt(Z%Ry>- z)|UfOVt`ZbPJ^L-nf8UHM+-Yu4Nz>nz`*L!A+RQE(p&$kRZufDW!_K^fLs!;gMvvg zjMGy{MMb}3n~;#49ATQ-{%9lbr(IDB)5!`sxUtp(3OTsK`i_#V(4l_T#~3&LHL!EH7Vc1R z2wWQb;7(A1JL3Rc4XXgOu=3H-Dnu8@HO^4u-ZPK;T~_&3&}BUMg^i=CRz zX{nEe((RgyuDS^@$!a<59}3kmvU4mI*Dlz$hOa6wjf=Vs2gOjOMZ2!vgClAk*)?P> zU^blHkoXwdWNEgyg9A9|z9k#=g-Lm8(pS?FWw6cUW%mQPN#idN(RCU_U!k$bDEo&6 z$iu&*SbYe~?KNA^`va*=ksm62PquzEEHfizF}FQ0j|!0)8;Y`vpDFrj1FRUWFq<~W z(hP!$Uy^(I?m;SiF$$~N)qq30WG{&-UeaVM~*u_u4J&0wn+I1FF--~G!MEfndTj#j^p~Gx#tHcq zjjN4;@iB;g2uc!3b5hU&r8s6e8!*@Nf#-TL5Ryypj0}7!qz`={*%J~4EAS=}wc?MS(OpoIN_IwOq#@wS4->$3(E=FxAWuN_&!7w}GxUVbnR=%I*KM-k zOS14JlXS@z0_OZBT?FtiT)Lo(7R&)kDVE-hNM8nL5aSVvD$&bF(7&x%%~v?+R!Eak)1-ka}qIyYi(F?!L66boqKOB{p#$}9u5cD9PBe)F%N6=prT#_|C2ORlil?^_b zm{cEp1%z9$+|uDCH`ReZ3NiLT9L{ZGy)Pbr_W5cEe;|ZEs#Fmh8=Q_~MEmbxZh?aX z6V4#h3$lcp{nlTABAWdM2L|`nGijc0u=q#g(I{2d0ovLCx<#TJww>v$)Fv*I)C!k; z>wle=x>8oCYFtVyk~}VI%S>g(I4jLV%aeL+ML90uL|R70rLHTTT0YHc>*@GKYFSxZ zu|E62`tIDd4QjX$wb3>#QWMMA#5P#N)CwzHy#~q(olR5yH{h1T35E##Ca5mGO}FR` z_-?|7{q+6+aRXe;bE+cT>*5x=ld?P4`{OE&mm;KIM|}i?1mvP&ePFeqXsiiC=Hz?w zl8mT(Aj0b(2rMHE`YO7;3FH}0Cf*Nor 
z3+rJf&*>m=@JHpEmb-h%@(mCNe~Y&1^j){;ziK3+j-$GRySNP9#SzA{+t+y`XN_;2 zPrJX!+xVFWn=c@b;^>QN4%ibyS!}xAMv%xXs)?;a*Y6ah~pDa@||oh3vbJ HcVGDvLF!Wi literal 0 HcmV?d00001 diff --git a/ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc b/ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e67083e4b85eaef7be2ac0afa12673871a12a7 GIT binary patch literal 1452 zcmaJ>y>A>v6rYdXxxMun8OKV95`u&f(iWT%1=Yz4A`+byWU3R;Q%UD0_<1~05b7A54Cw{+IsE|T<*Rxc`?(b!0W9C z>WKU$OnVbVM>`@($3W7OfMkjQ7TwRd-*~(@Jk&=d3z5BR=_`6EY+fU}Ft~JGzV92$BDAM62%jRf2KlO=}Wm7sS z1)x|If(n|_<@?vYYrOV;8F)@u>#RPi0jC-5wPIGmVH!eTY){EX{`cA*fwd1=OXd4Am-wWgbt~OV>(;J}MyC6~^ zGNB`o+jNWWuoUDL+o3mU#FlriCAdc@D4bpo)Uk#Gy*r@v9RVZ&%@^bWIm71S`)?e= z1NhDZn05>T3;ofJZ`q2jpaOs&ihRKu>cjDlok9Qp$iHXD44Wb2=JK6B4o#3nW&}vD z_stfV&=YL@si!kfPu@;QI#Ffe3h2(g5AY(GFF^#m`%1i8es*mFj@1{Fnu!V@3V&m5uTNh5#8qh0+}*WN?6N5AKPyRK=U(qR3wg~9jmO@51pYg`YsP(RQ^mQtBKStcCth`Uf|E?f;*QsbmmeVI}6!{6lo?S&niu*b{lOA-v{_6 zV?)zYE!R^0xsfzK(0--o+L6KaOp}AO#*ODXW(>@kM|E1cq3vjA7J7~IW^No=7m#-v z7`N8OEw$d}wcNNc&RS{b+|DgtdrBG_=QM9!kj~ac?X1nsyak;djLoAiuRk^9&+=N{ z;f*aV*SBFWZ&Cj>?aa>gvu^axLb3_T5HmKUl|*hUX|)s5-DtRwv>^Grg-mT*Ti1=< zr?#c69JMZWA1*}JXB{PtJ&h7IJ)!*y@QOIo~ zdK;2kKoAk1^>Q=sQ7qfqSwFAKzr4zGVA4k1ecstZ{0fOK_6+j2>Lbp403SE~CXf1rN^6n?**=D|x_w!8^WixM{U%Sxq4OoG#4%*V{ zHt@!|?d1IneMj{$y3RXjZ=|msy?z0_E_7Nu+se0&zLvg`ejQ$K9kJ)g8Q%Lu$9yy2 zno#BR-QP4ct^| z*d&WQszH{WC@m(IjngiN!Ucjp;``vD~Dj8PRSi)|Hagwrl zg1S{w@P)jDTon6>7>?RX+xO)(xvzQ@%6iqY+v7Br-fZ+<5WDFURsQ4Qea3w$5N&Cv z^rG@$>1e5FwA8zd|Mu9AQqtmOQi>>rw!}k0BhUpWFu(M7rOTVSezelHfcUGuTUw-O zDjR;3q;BL1piv#rp`ZhRWnj$dR95B|_^u3sn}EfHP$17wgg>1TvL%pg5=-WCUI9z~ zuaptXtu(!g;Gq66jsw!}PuTd0%*1$?d2tA873`)L11mbZn{2#(LgQ5qENJsm_^T6- zr9bsS?u+P0Y+q0cx$=}93eSaG$hay!z9h{2ge|mFW`6Y6-CcI~V|FqV(IsnRBKnaR zM-zWKq74ha9+QGHq!D|cKr7V|{xdhh7krQO-Thy2p2S(eFPZaVKk|Go*omLcRD#D5 zMG+rkk~1jwBcR7{cd_xSa35oL5Wt0F0l!ivCKKSiI@KBG$z^%iB}p^$93 zu#!D?rB6xbsPcMM@lYxXms9pEbPx{Z7D^=IkV+Fn&0}%AG=`iw5tmVLDcLAFg3~H^ zrQ;PE{ZyS2doshJ;6H$wWPwWIP^7b%CoIXl8Q{BQ=E12Kh=&ymjo9s3n$DB^dwZ_@ zg?~JX<#f+IO!nS>`|h35oqPB0eS5gvUS8IYBFI&i$Lh7A?3Mx)Xi5(c01XTX2!wX& zQBl4$kUVaZl-%_r^$=L2uaCyROsa=TGFo(@+R0KsNJd|wZu~I=>SJYkyyC(COA(ds zTDU6lsg!qFOd*|sb(ve)8nK_QSI`Otx5Qyr9pl88AaN0hkhl^Dwsyyq53&V~?>}?* z%hW_>`O!pxTtRq5KkA@w?V~(y?XuH2W063Cpp<16v0I5qor-#sTg=6MAxmCVM?zOx zg{ai=x_UuP(S%w*h@&Y6`2795(#D^{K6Vy7g3IdJ^;O?VPv?S7{bO84D9vZzVqu2B z4nfJNTDUx3C`5{nOM*v_PG()>_yt~e+Oy<=8&S2+kqGP%*jA-PM2vEW+3K-4Vvle) z0r${-ibyhCvq&r;3erHFYo(=(t?e8GJNFdKMZk?mcp|r!byzew?sX53WKMD6ja=h? 
zIQ&M@_EE9rT_c+W$QJm*SssU@j-zP*#0@g>u9UGX8YN+tf~mwsw>piY=LU07q~m2F zh>aHpnCPlfJ0UiA)bet*f%gJ}D(dRiF6v-tDT@A=4xzBqnH0$^4mb|D*l?D^>V4f= zBk!V)atr6h#@d_%Dx8X<>r9b46<5hR`BhKA=0aYhJ}^pMZ&uWTc#2;EKbL9ILM&z0 z7G?-%3uEr1C)@~5m%lnn-DX61ZcBR7OZsIcJJe8)$u2cLYOYYzM^kjw!BuHo*d;K8 znUfY9>m$m~`3HNmI23z_f$Kfp<9r^!H9u9Ss0Pk*pcojB zM`!M)-VBR==+OP=Kq2WH%}2K(kVVgU8 zL$`HPzkzlKtzq;aV>N8u!YuV|^i5;O*r{QkQR=Yj2G$xc?b_?6r5i7A7{(XI%bkX0 z$cHdCG$s8^dvAY#I4Bw@7UZ0yB9=93Ols~>a}y1IlU4za6k(T1_2z1JNkx~)odTA4 zjB91hb=KaiMGeo*qORU?yJ>YrtLv@FIrFLVij<(einAkrjEKP4f(&<2s*PRhrtW!e)4ym+*Hs{yWD@ zti;Q#jQg=S@y>kw&I0Ba<$QL39<(q<)bP#{o@e(LaJ4+RS{Yoe;tDIXGklSq#kIyu z_=opfDW(xWm^%j!&SNeGKGpcy8ZS4munWxiQTmm-yTr67N@e|Xl5!;zrtUaLVeYPJ z+vc9lL;arB@I(E(mgk4LM_u1?9n%iezvAFxr=o?+Pkegsp4DvjJUn}Bb-8UhJS++B zHB6h=9M|%At(r9zmv$^;jDpc@vi7~5F^L*Z|9a?_7X1lGfaQ8mH zj_JQ4wiIRU3pwf!)CXfeMX&O5zM&2pMtaHqi~ zX)l~<^()eUfUj={5uT;6-cz3c)~{>* zwf==*QT(aDe7r*Mv;2Z`qB?3&=o>`qcr`GB3s`~f6uO4&O02-9UZII3@#p@ zVfhmk^w#_D1nWUMz__9;D9cKp-uNsiv*Lj+KJ(887cp06)2~zgi#~BSjVBj^LQ}Vt zVD*pH-ytf6{-xmT@pXN)#i}uO173DztvFeqBHKr{q!I=a7$Ew(f@1zWt z0EzSto}NEmJ5fMqmNXQ-7tD>!&jxG6!K;+2eDX&>0Oab5uF##6-&U~7w+_Jt{rfed zGR+P@0kFYFru^nkzW{>8HjLBPF%kbOW z=*H+~G+klDh=32O1^N9Qc<3#8Mbsj2FgEI*FX{%~uM>zq!`<`xRpUK3#?tpH!%yh5 zr?iX&HTovH@(kZK$n&}6!Xk(oUEX-uG#uAA7;gf4*;T`JxY0J{XI<0tjQSW-)vE(0 zX+_Hc^qOq=Fv3!BZaMxKt5<0SvuFE~G68V4=WxH;aGj>rs(NOp3&8iRPr31)A*WxM zxF@u&HO&UV^OViTt62`D`;yz9;Wp(iWAQ=_a=%0NB*g!JaThK2^D}o!;Q2;YZ)}|+ z77*Iwf`q|1g-|j<3Xxvb_?|2H16^=@r$n&PGPqUws)rn%>BUcbKeL3R83jSp{DhWwIY z*Tk?Qh-upo-f{ zb--`DPb|tE%8{6{cLcBSeNyEFZ^rGh1-I)}IQT2~kqQBjw^z z8~Qb29Sz%s|G^q8hky+BK=9@Bmwti;^{ZnmCgu~_hcRLku}MhGiEU{p=82RCcJY6K zc7$bbhj?Rqt50wlCuXpR82OBcEeAdo13s;eToy_i$bQ4B^6FK?ieIyullmju8e7xm zfL*A=G3wj(5xYD!#r#Hu*q(4D3KBzrkoY)m53tubb_{ieH1_B4+pX#(8r^jTlFhE` zFu6w86lRC}SU?i|^FOPQ|B`u8a&L6WS8f?!>Y!lkf&t+TH#-uyDnpFu9!0bXq?p_AHa%G;5(KB7Oci_w{b->*z0`)=chn*SBwSl2{U>u__YB;$RI2ah!GVxSB)r|Goy)Q z1iOA<_JGf@XNI+fwbvueyh?5gP4qlC4!fLt&5QQg0`63COa1WMBlLeL^nW^elKuEFMDdNW#b;d9_u2 zFP4XKD63zPNCYz`grSfo)*E=z&cl%SjfF>-{)7e64IIUXO(|w%S((Hqruslqg1ymH^{p z&+0!+3NNo3;VAkial@(lH4UgovguNP4jY!`4T;Mf(^A(qDRA{rvkWae>BC?bcjF$| z*iI(2a>r;UMWf8UlF&qCfHM3A7e_`tAsRJz>QNqVu%00F87AAZ_8ZWPNsX}_la-)I zx^KFqLZXtGoCl>P%WY7-a0?6gz-g<)M{Z-RyaoKG*fu8 z0{JzPR)l&V^|aM!OQRj|SRV>-Ty_bP8D5t+s4f7;)0PvS68}6+K3FP}JneWyx2N{h zuq|rFy;;jq`wf9eHV8~8RjM0_;D9}+DvUCPOC2w2K*!;7|_$l9+XcG>dyM?&f+e2mE*=+Y(Etm)jaK1VCa&ymi z_clB5fK5*{MxJ+Eh9Z1(f^fBaG?K(7d2uanLPCV!R1R)cKbUZx2T3`Y5Q#bj1k7Ck zn5>qb-E})1G-U22-frwrH`z#9$ew{XZ9I{E2jdA<$5j1{3G2RX9&$Dm%&>Np#oI5S zc-y*p^TTafyq%WT4-`cu<6rzH;r*~?-2OHdjbjV`?|3%;M|=)#n(ExIlcd*ndo~g{ z>!slWh)~nHUJu2gY->eH&!=3#N@YrE+QyTe@5imIUwkQLiguRV+xKPDdZfAac90B= zuS?vm4|`v4N0BPBkr+hfm_bnntuTrzhh#Knyy#GPh&fIiW+aA2!63q!rX`T(bDLA< zL*(Mbe}KJ7+>QHF9#T-)B6+;W(GUapk9<3@kfeBsVkDse%1v4vN+C`oj2l#3G&8Xf z>W>2|?FH|`bLb$#SO<-3p4sHrkQ!V&bj88?V8IQ?YaUh%a70H5V6xG01!bE9e4-Ei zqZ$gX#sN7Hf z2&VA+1ixMxLkkB3&(nQ1P>ag%X`kRM>rWK%J|=KDrB%1W9QvR&l7W0Ehc9Zy5f%{?TCjGnwr7ek&(Zs!!#GZ6hWsT=HB3J&MU7?|z@q(z=|;Qf z9Edz*D~crG3>psHVPU4%W%PXt@*=Ax;|ynr06Zh^RflEKgQbpnM>s3%T9QE&oWm)A zsvqBmSr{A6?1b2+#j>&9DjT9i9FVyRI@GK0eG3cn)kcIJD61Br+Dg$F}3Y@kk4lM4_rwi1g zN^~|*Rq3m`nWGnREZOOGqXcnx3R!RR#qPAcaa$y-Jx@F!+Nu?*eprt5UmMj_fw~WL z63baw8eZ3sT?0#TRy&xk5%*yZhs2U{n3fs|=Mq1!;e!V6`eKV@L`V<~RbpWARF@>} zcz=jt!CX+6Qsq=eEvp5!pk=h2zLL^aZQ@_%i@9P>Enp@S&m`~YUlfZA_+D3M)H!uU zn^E=D@N50wwT%9yo~OB@Hlv+O6|{o-?Nm-(QH!ef8`@`UKPy~XQ*1tDGgu(q@t!CxfmV!3p%R~OY-!8 z^x!{o)vV*zTEb-FCcTwH=@F)7mKo+BB(1z~YNT=rb*ek1m`2nt4M6HX_Yn*Z|L@>7 z#S`y`7z{n9&1prwsAf{yh?<598^6+jg8QG%Xxe|_{r}A9+J9(i4Jy^XEN8WU%jC2# 
y@Lxy%b*Z5JOKwX0=Uh?yr`)vmkC_?mAGBHR@0aGZpK2xT?`F!{-)i%7KlyK{_|QoJ

literal 0
HcmV?d00001

diff --git a/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
new file mode 100644
index 0000000..43b2fe8
--- /dev/null
+++ b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
@@ -0,0 +1,143 @@
+# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
+# with the following modifications:
+# - `ddim_step_with_logprob` is a standalone function that takes the scheduler as its first argument, returns the
+#   previous sample together with the log probability of sampling it, and optionally evaluates the log probability
+#   of a provided `prev_sample` instead of drawing a new one.
+
+from typing import Optional, Tuple, Union
+
+import math
+import torch
+
+from diffusers.utils import randn_tensor
+from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
+
+
+def ddim_step_with_logprob(
+    self: DDIMScheduler,
+    model_output: torch.FloatTensor,
+    timestep: int,
+    sample: torch.FloatTensor,
+    eta: float = 0.0,
+    use_clipped_model_output: bool = False,
+    generator=None,
+    prev_sample: Optional[torch.FloatTensor] = None,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    """
+    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+    process from the learned model outputs (most often the predicted noise).
+
+    Args:
+        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+        timestep (`int`): current discrete timestep in the diffusion chain.
+        sample (`torch.FloatTensor`):
+            current instance of sample being created by diffusion process.
+        eta (`float`): weight of noise for added noise in diffusion step.
+        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+            coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+        generator: random number generator.
+        prev_sample (`torch.FloatTensor`, *optional*): if provided, the previous sample is not re-drawn; instead,
+            its log probability under this step's distribution is computed and returned.
+
+    Returns:
+        `tuple`: the previous sample `x_t-1` and the log probability of that sample under the Gaussian defined by
+        this DDIM step, averaged over all non-batch dimensions.
+
+    """
+    assert isinstance(self, DDIMScheduler)
+    if self.num_inference_steps is None:
+        raise ValueError(
+            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+        )
+
+    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+    # Ideally, read the DDIM paper in detail for a full understanding.
+
+    # Notation (<variable name> -> <name in paper>)
+    # - pred_noise_t -> e_theta(x_t, t)
+    # - pred_original_sample -> f_theta(x_t, t) or x_0
+    # - std_dev_t -> sigma_t
+    # - eta -> η
+    # - pred_sample_direction -> "direction pointing to x_t"
+    # - pred_prev_sample -> "x_t-1"
+
+    # 1. get previous step value (=t-1)
+    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+    # 2. compute alphas, betas
+    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
+    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
+    alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
+
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    if self.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        pred_epsilon = model_output
+    elif self.config.prediction_type == "sample":
+        pred_original_sample = model_output
+        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+    elif self.config.prediction_type == "v_prediction":
+        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+    else:
+        raise ValueError(
+            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+            " `v_prediction`"
+        )
+
+    # 4. Clip or threshold "predicted x_0"
+    if self.config.thresholding:
+        pred_original_sample = self._threshold_sample(pred_original_sample)
+    elif self.config.clip_sample:
+        pred_original_sample = pred_original_sample.clamp(
+            -self.config.clip_sample_range, self.config.clip_sample_range
+        )
+
+    # 5. compute variance: "sigma_t(η)" -> see formula (16)
+    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+    variance = self._get_variance(timestep, prev_timestep)
+    std_dev_t = eta * variance ** (0.5)
+
+    if use_clipped_model_output:
+        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+    if prev_sample is not None and generator is not None:
+        raise ValueError(
+            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
+            " `prev_sample` stays `None`."
+        )
+
+    if prev_sample is None:
+        variance_noise = randn_tensor(
+            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+        )
+        prev_sample = prev_sample_mean + std_dev_t * variance_noise
+
+    # log prob of prev_sample under the Gaussian N(prev_sample_mean, std_dev_t**2) of this DDIM step
+    log_prob = (
+        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
+        - torch.log(std_dev_t)
+        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
+    )
+    # mean along all but batch dimension
+    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+    return prev_sample, log_prob
diff --git a/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py
new file mode 100644
index 0000000..09378c2
--- /dev/null
+++ b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py
@@ -0,0 +1,225 @@
+# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+# with the following modifications:
+# - `pipeline_with_logprob` is a standalone function that takes the pipeline as its first argument, performs the
+#   scheduler step through `ddim_step_with_logprob`, and returns `(images, has_nsfw_concept, all_latents,
+#   all_log_probs)` instead of a `StableDiffusionPipelineOutput`.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
+    StableDiffusionPipeline,
+    rescale_noise_cfg,
+)
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+from .ddim_with_logprob import ddim_step_with_logprob
+
+
+@torch.no_grad()
+def pipeline_with_logprob(
+    self: StableDiffusionPipeline,
+    prompt: Union[str, List[str]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 50,
+    guidance_scale: float = 7.5,
+    negative_prompt: Optional[Union[str, List[str]]] = None,
+    num_images_per_prompt: Optional[int] = 1,
+    eta: float = 0.0,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+    callback_steps: int = 1,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    guidance_rescale: float = 0.0,
+):
+    r"""
+    Function invoked when calling the pipeline for generation.
+
+    Args:
+        prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+            instead.
+        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            The height in pixels of the generated image.
+        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            The width in pixels of the generated image.
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+            expense of slower inference.
+        guidance_scale (`float`, *optional*, defaults to 7.5):
+            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+            `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+            1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+            text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs) + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image, has_nsfw_concept, all_latents, all_log_probs diff --git a/ddpo_pytorch/prompts.py b/ddpo_pytorch/prompts.py new file mode 100644 index 0000000..8cecf28 --- /dev/null +++ b/ddpo_pytorch/prompts.py @@ -0,0 +1,54 @@ +from importlib import resources +import functools +import random +import inflect + +IE = inflect.engine() +ASSETS_PATH = resources.files("ddpo_pytorch.assets") + + +@functools.cache +def load_lines(name): + with ASSETS_PATH.joinpath(name).open() as f: + return [line.strip() for line in f.readlines()] + + +def imagenet(low, high): + return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} + + +def imagenet_all(): + return imagenet(0, 1000) + + +def imagenet_animals(): + return imagenet(0, 398) + + +def imagenet_dogs(): + return imagenet(151, 269) + + +def nouns_activities(nouns_file, activities_file): + nouns = load_lines(nouns_file) + activities = load_lines(activities_file) + return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} + + +def counting(nouns_file, low, high): + nouns = load_lines(nouns_file) + number = IE.number_to_words(random.randint(low, high)) + noun = random.choice(nouns) + plural_noun = IE.plural(noun) + prompt = f"{number} {plural_noun}" + metadata = { + "questions": [ + f"How many {plural_noun} are there in this image?", + f"What animal is in this image?", + ], + "answers": [ + number, + noun, + ], + } + return prompt, metadata diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py new file mode 100644 index 0000000..9ec6218 --- /dev/null +++ b/ddpo_pytorch/rewards.py @@ -0,0 +1,29 @@ +from PIL import Image +import io +import numpy as np +import torch + + +def jpeg_incompressibility(): + def _fn(images, prompts, metadata): + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images = [Image.fromarray(image) for image in images] + buffers = [io.BytesIO() for _ in images] + for image, buffer in zip(images, buffers): + image.save(buffer, format="JPEG", quality=95) + sizes = [buffer.tell() / 1000 for buffer in buffers] + return np.array(sizes), {} + + return _fn + + +def jpeg_compressibility(): + jpeg_fn = jpeg_incompressibility() + + 
def _fn(images, prompts, metadata): + rew, meta = jpeg_fn(images, prompts, metadata) + return -rew, meta + + return _fn diff --git a/ddpo_pytorch/stat_tracking.py b/ddpo_pytorch/stat_tracking.py new file mode 100644 index 0000000..4199ab9 --- /dev/null +++ b/ddpo_pytorch/stat_tracking.py @@ -0,0 +1,34 @@ +import numpy as np +from collections import deque + + +class PerPromptStatTracker: + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return { + k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} + for k, v in self.stats.items() + } diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..c123ba7 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,341 @@ +from absl import app, flags, logging +from ml_collections import config_flags +from accelerate import Accelerator +from accelerate.utils import set_seed +from accelerate.logging import get_logger +from diffusers import StableDiffusionPipeline, DDIMScheduler +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +import ddpo_pytorch.prompts +import ddpo_pytorch.rewards +from ddpo_pytorch.stat_tracking import PerPromptStatTracker +from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob +from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob +import torch +import tqdm + + +FLAGS = flags.FLAGS +config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") + +logger = get_logger(__name__) + + +def main(_): + # basic Accelerate and logging setup + config = FLAGS.config + accelerator = Accelerator( + log_with="all", + mixed_precision=config.mixed_precision, + project_dir=config.logdir, + ) + if accelerator.is_main_process: + accelerator.init_trackers(project_name="ddpo-pytorch", config=config) + logger.info(config) + + # set seed + set_seed(config.seed) + + # load scheduler, tokenizer and models. + pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) + # freeze parameters of models to save more memory + pipeline.unet.requires_grad_(False) + pipeline.vae.requires_grad_(False) + pipeline.text_encoder.requires_grad_(False) + # disable safety checker + pipeline.safety_checker = None + # make the progress bar nicer + pipeline.set_progress_bar_config( + position=1, + disable=not accelerator.is_local_main_process, + leave=False, + ) + # switch to DDIM scheduler + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + pipeline.unet.to(accelerator.device, dtype=weight_dtype) + pipeline.vae.to(accelerator.device, dtype=weight_dtype) + pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Set correct lora layers + lora_attn_procs = {} + for name in pipeline.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipeline.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + + pipeline.unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(pipeline.unet.attn_processors) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if config.train.scale_lr: + config.train.learning_rate = ( + config.train.learning_rate + * config.train.gradient_accumulation_steps + * config.train.batch_size + * accelerator.num_processes + ) + + # Initialize the optimizer + if config.train.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=config.train.learning_rate, + betas=(config.train.adam_beta1, config.train.adam_beta2), + weight_decay=config.train.adam_weight_decay, + eps=config.train.adam_epsilon, + ) + + # prepare prompt and reward fn + prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) + reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() + + # Prepare everything with our `accelerator`. + lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer) + + # Train! + samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch + total_train_batch_size = ( + config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps + ) + + assert config.sample.batch_size % config.train.batch_size == 0 + assert samples_per_epoch % total_train_batch_size == 0 + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {config.num_epochs}") + logger.info(f" Sample batch size per device = {config.sample.batch_size}") + logger.info(f" Train batch size per device = {config.train.batch_size}") + logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") + logger.info("") + logger.info(f" Total number of samples per epoch = {samples_per_epoch}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}") + logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}") + + neg_prompt_embed = pipeline.text_encoder( + pipeline.tokenizer( + [""], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + )[0] + sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) + train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) + + if config.per_prompt_stat_tracking: + stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking.buffer_size, + config.per_prompt_stat_tracking.min_count, + ) + + for epoch in range(config.num_epochs): + #################### SAMPLING #################### + samples = [] + prompts = [] + for i in tqdm.tqdm( + range(config.sample.num_batches_per_epoch), + desc=f"Epoch {epoch}: sampling", + disable=not accelerator.is_local_main_process, + position=0, + ): + # generate prompts + prompts, prompt_metadata = zip( + *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)] + ) + + # encode prompts + prompt_ids = pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + prompt_embeds = pipeline.text_encoder(prompt_ids)[0] + + # sample + pipeline.unet.eval() + pipeline.vae.eval() + images, _, latents, log_probs = pipeline_with_logprob( + pipeline, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=config.sample.num_steps, + guidance_scale=config.sample.guidance_scale, + eta=config.sample.eta, + output_type="pt", + ) + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1) # (batch_size, num_steps) + + # compute rewards + rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "rewards": torch.as_tensor(rewards), + } + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) 
+        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
+
+        # gather rewards across processes
+        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
+
+        # per-prompt mean/std tracking
+        if config.per_prompt_stat_tracking:
+            # gather the prompts across processes
+            prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
+            prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
+            advantages = stat_tracker.update(prompts, rewards)
+        else:
+            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
+
+        # ungather advantages; we only need to keep the entries corresponding to the samples on this process
+        samples["advantages"] = (
+            torch.as_tensor(advantages)
+            .reshape(accelerator.num_processes, -1)[accelerator.process_index]
+            .to(accelerator.device)
+        )
+
+        del samples["rewards"]
+        del samples["prompt_ids"]
+
+        total_batch_size, num_timesteps = samples["timesteps"].shape
+        assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch
+        assert num_timesteps == config.sample.num_steps
+
+        #################### TRAINING ####################
+        for inner_epoch in range(config.train.num_inner_epochs):
+            # shuffle samples along batch dimension
+            indices = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[indices] for k, v in samples.items()}
+
+            # shuffle along time dimension, independently for each sample; log_probs must be permuted together
+            # with the timesteps and latents so that the PPO ratio below compares matching steps
+            for i in range(total_batch_size):
+                indices = torch.randperm(num_timesteps, device=accelerator.device)
+                for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                    samples[key][i] = samples[key][i][indices]
+
+            # rebatch for training
+            samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
+
+            # dict of lists -> list of dicts for easier iteration
+            samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
+
+            # train
+            for i, sample in tqdm.tqdm(
+                list(enumerate(samples_batched)),
+                desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training",
+                position=0,
+            ):
+                if config.train.cfg:
+                    # concat negative prompts to sample prompts to avoid two forward passes
+                    embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]])
+                else:
+                    embeds = sample["prompt_embeds"]
+
+                for j in tqdm.trange(
+                    num_timesteps,
+                    desc="Timestep",
+                    position=1,
+                    leave=False,
+                ):
+                    with accelerator.accumulate(pipeline.unet):
+                        if config.train.cfg:
+                            noise_pred = pipeline.unet(
+                                torch.cat([sample["latents"][:, j]] * 2),
+                                torch.cat([sample["timesteps"][:, j]] * 2),
+                                embeds,
+                            ).sample
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
+                                noise_pred_text - noise_pred_uncond
+                            )
+                        else:
+                            noise_pred = pipeline.unet(
+                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                            ).sample
+                        _, log_prob = ddim_step_with_logprob(
+                            pipeline.scheduler,
+                            noise_pred,
+                            sample["timesteps"][:, j],
+                            sample["latents"][:, j],
+                            eta=config.sample.eta,
+                            prev_sample=sample["next_latents"][:, j],
+                        )
+
+                        # ppo logic; advantages are per-sample (one per trajectory), not per-timestep
+                        advantages = torch.clamp(
+                            sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max
+                        )
+                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
+                        unclipped_loss = -advantages * ratio
+                        clipped_loss = -advantages * torch.clamp(
+                            ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
+                        )
+                        loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
+
+                        # debugging values
+                        info = {}
+                        # John Schulman says that (ratio - 1) - log(ratio) is a better
+                        # estimator, but most existing code uses this so...
+                        # http://joschu.net/blog/kl-approx.html
+                        info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
+                        # fraction of samples whose ratio was clipped; cast the boolean mask to float so it can be averaged
+                        info["clipfrac"] = torch.mean((torch.abs(ratio - 1.0) > config.train.clip_range).float())
+                        info["loss"] = loss
+
+                        # backward pass
+                        accelerator.backward(loss)
+                        if accelerator.sync_gradients:
+                            accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
+                        optimizer.step()
+                        optimizer.zero_grad()
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..076cb2a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,10 @@
+from setuptools import setup
+
+setup(
+    name='ddpo-pytorch',
+    version='0.0.1',
+    packages=["ddpo_pytorch"],
+    install_requires=[
+        "ml-collections", "absl-py", "torch", "diffusers", "accelerate", "numpy", "Pillow", "inflect", "tqdm"
+    ],
+)
\ No newline at end of file
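
For reference, below is a minimal, self-contained sketch of how the PerPromptStatTracker added in ddpo_pytorch/stat_tracking.py turns raw rewards into advantages. The prompts and reward values are made up purely for illustration, and it assumes the package from this patch has been installed (e.g. with `pip install -e .`).

import numpy as np

from ddpo_pytorch.stat_tracking import PerPromptStatTracker

# same buffer_size/min_count defaults as config/base.py
tracker = PerPromptStatTracker(buffer_size=128, min_count=16)

# pretend one sampling epoch produced two images for each of two prompts
prompts = np.array(["a dog", "a dog", "a cat", "a cat"])
rewards = np.array([1.0, 3.0, -2.0, 0.0])

# each reward is normalized by the mean/std of its prompt's reward buffer;
# until a prompt has at least min_count buffered rewards, the whole batch's
# statistics are used as a fallback
advantages = tracker.update(prompts, rewards)

print(advantages)           # one advantage per sample, in the same order as `rewards`
print(tracker.get_stats())  # per-prompt mean/std/count of the buffered rewards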