Spaces:
Running
on
Zero
Running
on
Zero
daidedou
commited on
Commit
·
df60d6b
1
Parent(s):
1245229
forgot this
Browse files- edm/Dockerfile +18 -0
- edm/LICENSE.txt +439 -0
- edm/README.md +246 -0
- edm/dataset_tool.py +440 -0
- edm/dnnlib/__init__.py +8 -0
- edm/dnnlib/util.py +491 -0
- edm/environment.yml +19 -0
- edm/example.py +94 -0
- edm/fid.py +165 -0
- edm/generate.py +316 -0
- edm/torch_utils/__init__.py +8 -0
- edm/torch_utils/distributed.py +59 -0
- edm/torch_utils/misc.py +266 -0
- edm/torch_utils/persistence.py +257 -0
- edm/torch_utils/training_stats.py +272 -0
- edm/train.py +236 -0
- edm/training/__init__.py +8 -0
- edm/training/augment.py +330 -0
- edm/training/dataset.py +250 -0
- edm/training/loss.py +82 -0
- edm/training/networks.py +673 -0
- edm/training/training_loop.py +216 -0
edm/Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
FROM nvcr.io/nvidia/pytorch:22.10-py3
|
| 9 |
+
|
| 10 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
| 11 |
+
ENV PYTHONUNBUFFERED 1
|
| 12 |
+
|
| 13 |
+
RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
|
| 14 |
+
|
| 15 |
+
WORKDIR /workspace
|
| 16 |
+
|
| 17 |
+
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
|
| 18 |
+
ENTRYPOINT ["/entry.sh"]
|
edm/LICENSE.txt
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
|
| 3 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
| 4 |
+
|
| 5 |
+
=======================================================================
|
| 6 |
+
|
| 7 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 8 |
+
does not provide legal services or legal advice. Distribution of
|
| 9 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 10 |
+
other relationship. Creative Commons makes its licenses and related
|
| 11 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 12 |
+
warranties regarding its licenses, any material licensed under their
|
| 13 |
+
terms and conditions, or any related information. Creative Commons
|
| 14 |
+
disclaims all liability for damages resulting from their use to the
|
| 15 |
+
fullest extent possible.
|
| 16 |
+
|
| 17 |
+
Using Creative Commons Public Licenses
|
| 18 |
+
|
| 19 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 20 |
+
conditions that creators and other rights holders may use to share
|
| 21 |
+
original works of authorship and other material subject to copyright
|
| 22 |
+
and certain other rights specified in the public license below. The
|
| 23 |
+
following considerations are for informational purposes only, are not
|
| 24 |
+
exhaustive, and do not form part of our licenses.
|
| 25 |
+
|
| 26 |
+
Considerations for licensors: Our public licenses are
|
| 27 |
+
intended for use by those authorized to give the public
|
| 28 |
+
permission to use material in ways otherwise restricted by
|
| 29 |
+
copyright and certain other rights. Our licenses are
|
| 30 |
+
irrevocable. Licensors should read and understand the terms
|
| 31 |
+
and conditions of the license they choose before applying it.
|
| 32 |
+
Licensors should also secure all rights necessary before
|
| 33 |
+
applying our licenses so that the public can reuse the
|
| 34 |
+
material as expected. Licensors should clearly mark any
|
| 35 |
+
material not subject to the license. This includes other CC-
|
| 36 |
+
licensed material, or material used under an exception or
|
| 37 |
+
limitation to copyright. More considerations for licensors:
|
| 38 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 39 |
+
|
| 40 |
+
Considerations for the public: By using one of our public
|
| 41 |
+
licenses, a licensor grants the public permission to use the
|
| 42 |
+
licensed material under specified terms and conditions. If
|
| 43 |
+
the licensor's permission is not necessary for any reason--for
|
| 44 |
+
example, because of any applicable exception or limitation to
|
| 45 |
+
copyright--then that use is not regulated by the license. Our
|
| 46 |
+
licenses grant only permissions under copyright and certain
|
| 47 |
+
other rights that a licensor has authority to grant. Use of
|
| 48 |
+
the licensed material may still be restricted for other
|
| 49 |
+
reasons, including because others have copyright or other
|
| 50 |
+
rights in the material. A licensor may make special requests,
|
| 51 |
+
such as asking that all changes be marked or described.
|
| 52 |
+
Although not required by our licenses, you are encouraged to
|
| 53 |
+
respect those requests where reasonable. More considerations
|
| 54 |
+
for the public:
|
| 55 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 56 |
+
|
| 57 |
+
=======================================================================
|
| 58 |
+
|
| 59 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
| 60 |
+
Public License
|
| 61 |
+
|
| 62 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 63 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 64 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
| 65 |
+
("Public License"). To the extent this Public License may be
|
| 66 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
| 67 |
+
consideration of Your acceptance of these terms and conditions, and the
|
| 68 |
+
Licensor grants You such rights in consideration of benefits the
|
| 69 |
+
Licensor receives from making the Licensed Material available under
|
| 70 |
+
these terms and conditions.
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
Section 1 -- Definitions.
|
| 74 |
+
|
| 75 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 76 |
+
Rights that is derived from or based upon the Licensed Material
|
| 77 |
+
and in which the Licensed Material is translated, altered,
|
| 78 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 79 |
+
permission under the Copyright and Similar Rights held by the
|
| 80 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 81 |
+
Material is a musical work, performance, or sound recording,
|
| 82 |
+
Adapted Material is always produced where the Licensed Material is
|
| 83 |
+
synched in timed relation with a moving image.
|
| 84 |
+
|
| 85 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 86 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 87 |
+
accordance with the terms and conditions of this Public License.
|
| 88 |
+
|
| 89 |
+
c. BY-NC-SA Compatible License means a license listed at
|
| 90 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
| 91 |
+
Commons as essentially the equivalent of this Public License.
|
| 92 |
+
|
| 93 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
| 94 |
+
closely related to copyright including, without limitation,
|
| 95 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 96 |
+
Rights, without regard to how the rights are labeled or
|
| 97 |
+
categorized. For purposes of this Public License, the rights
|
| 98 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 99 |
+
Rights.
|
| 100 |
+
|
| 101 |
+
e. Effective Technological Measures means those measures that, in the
|
| 102 |
+
absence of proper authority, may not be circumvented under laws
|
| 103 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 104 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 105 |
+
agreements.
|
| 106 |
+
|
| 107 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 108 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 109 |
+
that applies to Your use of the Licensed Material.
|
| 110 |
+
|
| 111 |
+
g. License Elements means the license attributes listed in the name
|
| 112 |
+
of a Creative Commons Public License. The License Elements of this
|
| 113 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
| 114 |
+
|
| 115 |
+
h. Licensed Material means the artistic or literary work, database,
|
| 116 |
+
or other material to which the Licensor applied this Public
|
| 117 |
+
License.
|
| 118 |
+
|
| 119 |
+
i. Licensed Rights means the rights granted to You subject to the
|
| 120 |
+
terms and conditions of this Public License, which are limited to
|
| 121 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 122 |
+
Licensed Material and that the Licensor has authority to license.
|
| 123 |
+
|
| 124 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
| 125 |
+
under this Public License.
|
| 126 |
+
|
| 127 |
+
k. NonCommercial means not primarily intended for or directed towards
|
| 128 |
+
commercial advantage or monetary compensation. For purposes of
|
| 129 |
+
this Public License, the exchange of the Licensed Material for
|
| 130 |
+
other material subject to Copyright and Similar Rights by digital
|
| 131 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 132 |
+
no payment of monetary compensation in connection with the
|
| 133 |
+
exchange.
|
| 134 |
+
|
| 135 |
+
l. Share means to provide material to the public by any means or
|
| 136 |
+
process that requires permission under the Licensed Rights, such
|
| 137 |
+
as reproduction, public display, public performance, distribution,
|
| 138 |
+
dissemination, communication, or importation, and to make material
|
| 139 |
+
available to the public including in ways that members of the
|
| 140 |
+
public may access the material from a place and at a time
|
| 141 |
+
individually chosen by them.
|
| 142 |
+
|
| 143 |
+
m. Sui Generis Database Rights means rights other than copyright
|
| 144 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 145 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 146 |
+
as amended and/or succeeded, as well as other essentially
|
| 147 |
+
equivalent rights anywhere in the world.
|
| 148 |
+
|
| 149 |
+
n. You means the individual or entity exercising the Licensed Rights
|
| 150 |
+
under this Public License. Your has a corresponding meaning.
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
Section 2 -- Scope.
|
| 154 |
+
|
| 155 |
+
a. License grant.
|
| 156 |
+
|
| 157 |
+
1. Subject to the terms and conditions of this Public License,
|
| 158 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 159 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 160 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 161 |
+
|
| 162 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 163 |
+
in part, for NonCommercial purposes only; and
|
| 164 |
+
|
| 165 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 166 |
+
NonCommercial purposes only.
|
| 167 |
+
|
| 168 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 169 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 170 |
+
License does not apply, and You do not need to comply with
|
| 171 |
+
its terms and conditions.
|
| 172 |
+
|
| 173 |
+
3. Term. The term of this Public License is specified in Section
|
| 174 |
+
6(a).
|
| 175 |
+
|
| 176 |
+
4. Media and formats; technical modifications allowed. The
|
| 177 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 178 |
+
all media and formats whether now known or hereafter created,
|
| 179 |
+
and to make technical modifications necessary to do so. The
|
| 180 |
+
Licensor waives and/or agrees not to assert any right or
|
| 181 |
+
authority to forbid You from making technical modifications
|
| 182 |
+
necessary to exercise the Licensed Rights, including
|
| 183 |
+
technical modifications necessary to circumvent Effective
|
| 184 |
+
Technological Measures. For purposes of this Public License,
|
| 185 |
+
simply making modifications authorized by this Section 2(a)
|
| 186 |
+
(4) never produces Adapted Material.
|
| 187 |
+
|
| 188 |
+
5. Downstream recipients.
|
| 189 |
+
|
| 190 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 191 |
+
recipient of the Licensed Material automatically
|
| 192 |
+
receives an offer from the Licensor to exercise the
|
| 193 |
+
Licensed Rights under the terms and conditions of this
|
| 194 |
+
Public License.
|
| 195 |
+
|
| 196 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
| 197 |
+
Every recipient of Adapted Material from You
|
| 198 |
+
automatically receives an offer from the Licensor to
|
| 199 |
+
exercise the Licensed Rights in the Adapted Material
|
| 200 |
+
under the conditions of the Adapter's License You apply.
|
| 201 |
+
|
| 202 |
+
c. No downstream restrictions. You may not offer or impose
|
| 203 |
+
any additional or different terms or conditions on, or
|
| 204 |
+
apply any Effective Technological Measures to, the
|
| 205 |
+
Licensed Material if doing so restricts exercise of the
|
| 206 |
+
Licensed Rights by any recipient of the Licensed
|
| 207 |
+
Material.
|
| 208 |
+
|
| 209 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 210 |
+
may be construed as permission to assert or imply that You
|
| 211 |
+
are, or that Your use of the Licensed Material is, connected
|
| 212 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 213 |
+
the Licensor or others designated to receive attribution as
|
| 214 |
+
provided in Section 3(a)(1)(A)(i).
|
| 215 |
+
|
| 216 |
+
b. Other rights.
|
| 217 |
+
|
| 218 |
+
1. Moral rights, such as the right of integrity, are not
|
| 219 |
+
licensed under this Public License, nor are publicity,
|
| 220 |
+
privacy, and/or other similar personality rights; however, to
|
| 221 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 222 |
+
assert any such rights held by the Licensor to the limited
|
| 223 |
+
extent necessary to allow You to exercise the Licensed
|
| 224 |
+
Rights, but not otherwise.
|
| 225 |
+
|
| 226 |
+
2. Patent and trademark rights are not licensed under this
|
| 227 |
+
Public License.
|
| 228 |
+
|
| 229 |
+
3. To the extent possible, the Licensor waives any right to
|
| 230 |
+
collect royalties from You for the exercise of the Licensed
|
| 231 |
+
Rights, whether directly or through a collecting society
|
| 232 |
+
under any voluntary or waivable statutory or compulsory
|
| 233 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 234 |
+
reserves any right to collect such royalties, including when
|
| 235 |
+
the Licensed Material is used other than for NonCommercial
|
| 236 |
+
purposes.
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
Section 3 -- License Conditions.
|
| 240 |
+
|
| 241 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 242 |
+
following conditions.
|
| 243 |
+
|
| 244 |
+
a. Attribution.
|
| 245 |
+
|
| 246 |
+
1. If You Share the Licensed Material (including in modified
|
| 247 |
+
form), You must:
|
| 248 |
+
|
| 249 |
+
a. retain the following if it is supplied by the Licensor
|
| 250 |
+
with the Licensed Material:
|
| 251 |
+
|
| 252 |
+
i. identification of the creator(s) of the Licensed
|
| 253 |
+
Material and any others designated to receive
|
| 254 |
+
attribution, in any reasonable manner requested by
|
| 255 |
+
the Licensor (including by pseudonym if
|
| 256 |
+
designated);
|
| 257 |
+
|
| 258 |
+
ii. a copyright notice;
|
| 259 |
+
|
| 260 |
+
iii. a notice that refers to this Public License;
|
| 261 |
+
|
| 262 |
+
iv. a notice that refers to the disclaimer of
|
| 263 |
+
warranties;
|
| 264 |
+
|
| 265 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 266 |
+
extent reasonably practicable;
|
| 267 |
+
|
| 268 |
+
b. indicate if You modified the Licensed Material and
|
| 269 |
+
retain an indication of any previous modifications; and
|
| 270 |
+
|
| 271 |
+
c. indicate the Licensed Material is licensed under this
|
| 272 |
+
Public License, and include the text of, or the URI or
|
| 273 |
+
hyperlink to, this Public License.
|
| 274 |
+
|
| 275 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 276 |
+
reasonable manner based on the medium, means, and context in
|
| 277 |
+
which You Share the Licensed Material. For example, it may be
|
| 278 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 279 |
+
hyperlink to a resource that includes the required
|
| 280 |
+
information.
|
| 281 |
+
3. If requested by the Licensor, You must remove any of the
|
| 282 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 283 |
+
reasonably practicable.
|
| 284 |
+
|
| 285 |
+
b. ShareAlike.
|
| 286 |
+
|
| 287 |
+
In addition to the conditions in Section 3(a), if You Share
|
| 288 |
+
Adapted Material You produce, the following conditions also apply.
|
| 289 |
+
|
| 290 |
+
1. The Adapter's License You apply must be a Creative Commons
|
| 291 |
+
license with the same License Elements, this version or
|
| 292 |
+
later, or a BY-NC-SA Compatible License.
|
| 293 |
+
|
| 294 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
| 295 |
+
Adapter's License You apply. You may satisfy this condition
|
| 296 |
+
in any reasonable manner based on the medium, means, and
|
| 297 |
+
context in which You Share Adapted Material.
|
| 298 |
+
|
| 299 |
+
3. You may not offer or impose any additional or different terms
|
| 300 |
+
or conditions on, or apply any Effective Technological
|
| 301 |
+
Measures to, Adapted Material that restrict exercise of the
|
| 302 |
+
rights granted under the Adapter's License You apply.
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
Section 4 -- Sui Generis Database Rights.
|
| 306 |
+
|
| 307 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 308 |
+
apply to Your use of the Licensed Material:
|
| 309 |
+
|
| 310 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 311 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 312 |
+
portion of the contents of the database for NonCommercial purposes
|
| 313 |
+
only;
|
| 314 |
+
|
| 315 |
+
b. if You include all or a substantial portion of the database
|
| 316 |
+
contents in a database in which You have Sui Generis Database
|
| 317 |
+
Rights, then the database in which You have Sui Generis Database
|
| 318 |
+
Rights (but not its individual contents) is Adapted Material,
|
| 319 |
+
including for purposes of Section 3(b); and
|
| 320 |
+
|
| 321 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 322 |
+
all or a substantial portion of the contents of the database.
|
| 323 |
+
|
| 324 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 325 |
+
replace Your obligations under this Public License where the Licensed
|
| 326 |
+
Rights include other Copyright and Similar Rights.
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 330 |
+
|
| 331 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 332 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 333 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 334 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 335 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 336 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 337 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 338 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 339 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 340 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 341 |
+
|
| 342 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 343 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 344 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 345 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 346 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 347 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 348 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 349 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 350 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 351 |
+
|
| 352 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 353 |
+
above shall be interpreted in a manner that, to the extent
|
| 354 |
+
possible, most closely approximates an absolute disclaimer and
|
| 355 |
+
waiver of all liability.
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
Section 6 -- Term and Termination.
|
| 359 |
+
|
| 360 |
+
a. This Public License applies for the term of the Copyright and
|
| 361 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 362 |
+
this Public License, then Your rights under this Public License
|
| 363 |
+
terminate automatically.
|
| 364 |
+
|
| 365 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 366 |
+
Section 6(a), it reinstates:
|
| 367 |
+
|
| 368 |
+
1. automatically as of the date the violation is cured, provided
|
| 369 |
+
it is cured within 30 days of Your discovery of the
|
| 370 |
+
violation; or
|
| 371 |
+
|
| 372 |
+
2. upon express reinstatement by the Licensor.
|
| 373 |
+
|
| 374 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 375 |
+
right the Licensor may have to seek remedies for Your violations
|
| 376 |
+
of this Public License.
|
| 377 |
+
|
| 378 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 379 |
+
Licensed Material under separate terms or conditions or stop
|
| 380 |
+
distributing the Licensed Material at any time; however, doing so
|
| 381 |
+
will not terminate this Public License.
|
| 382 |
+
|
| 383 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 384 |
+
License.
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
Section 7 -- Other Terms and Conditions.
|
| 388 |
+
|
| 389 |
+
a. The Licensor shall not be bound by any additional or different
|
| 390 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 391 |
+
|
| 392 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 393 |
+
Licensed Material not stated herein are separate from and
|
| 394 |
+
independent of the terms and conditions of this Public License.
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
Section 8 -- Interpretation.
|
| 398 |
+
|
| 399 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 400 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 401 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 402 |
+
be made without permission under this Public License.
|
| 403 |
+
|
| 404 |
+
b. To the extent possible, if any provision of this Public License is
|
| 405 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 406 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 407 |
+
cannot be reformed, it shall be severed from this Public License
|
| 408 |
+
without affecting the enforceability of the remaining terms and
|
| 409 |
+
conditions.
|
| 410 |
+
|
| 411 |
+
c. No term or condition of this Public License will be waived and no
|
| 412 |
+
failure to comply consented to unless expressly agreed to by the
|
| 413 |
+
Licensor.
|
| 414 |
+
|
| 415 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 416 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 417 |
+
that apply to the Licensor or You, including from the legal
|
| 418 |
+
processes of any jurisdiction or authority.
|
| 419 |
+
|
| 420 |
+
=======================================================================
|
| 421 |
+
|
| 422 |
+
Creative Commons is not a party to its public
|
| 423 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 424 |
+
its public licenses to material it publishes and in those instances
|
| 425 |
+
will be considered the "Licensor." The text of the Creative Commons
|
| 426 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 427 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 428 |
+
material is shared under a Creative Commons public license or as
|
| 429 |
+
otherwise permitted by the Creative Commons policies published at
|
| 430 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 431 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 432 |
+
of Creative Commons without its prior written consent including,
|
| 433 |
+
without limitation, in connection with any unauthorized modifications
|
| 434 |
+
to any of its public licenses or any other arrangements,
|
| 435 |
+
understandings, or agreements concerning use of licensed material. For
|
| 436 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 437 |
+
public licenses.
|
| 438 |
+
|
| 439 |
+
Creative Commons may be contacted at creativecommons.org.
|
edm/README.md
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Elucidating the Design Space of Diffusion-Based Generative Models (EDM)<br><sub>Official PyTorch implementation of the NeurIPS 2022 paper</sub>
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
**Elucidating the Design Space of Diffusion-Based Generative Models**<br>
|
| 6 |
+
Tero Karras, Miika Aittala, Timo Aila, Samuli Laine
|
| 7 |
+
<br>https://arxiv.org/abs/2206.00364<br>
|
| 8 |
+
|
| 9 |
+
Abstract: *We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
|
| 10 |
+
|
| 11 |
+
For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
|
| 12 |
+
|
| 13 |
+
## Requirements
|
| 14 |
+
|
| 15 |
+
* Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
|
| 16 |
+
* 1+ high-end NVIDIA GPU for sampling and 8+ GPUs for training. We have done all testing and development using V100 and A100 GPUs.
|
| 17 |
+
* 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch install instructions.
|
| 18 |
+
* Python libraries: See [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
|
| 19 |
+
- `conda env create -f environment.yml -n edm`
|
| 20 |
+
- `conda activate edm`
|
| 21 |
+
* Docker users:
|
| 22 |
+
- Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
|
| 23 |
+
- Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.
|
| 24 |
+
|
| 25 |
+
## Getting started
|
| 26 |
+
|
| 27 |
+
To reproduce the main results from our paper, simply run:
|
| 28 |
+
|
| 29 |
+
```.bash
|
| 30 |
+
python example.py
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
This is a minimal standalone script that loads the best pre-trained model for each dataset and generates a random 8x8 grid of images using the optimal sampler settings. Expected results:
|
| 34 |
+
|
| 35 |
+
| Dataset | Runtime | Reference image
|
| 36 |
+
| :------- | :------ | :--------------
|
| 37 |
+
| CIFAR-10 | ~6 sec | [`cifar10-32x32.png`](./docs/cifar10-32x32.png)
|
| 38 |
+
| FFHQ | ~28 sec | [`ffhq-64x64.png`](./docs/ffhq-64x64.png)
|
| 39 |
+
| AFHQv2 | ~28 sec | [`afhqv2-64x64.png`](./docs/afhqv2-64x64.png)
|
| 40 |
+
| ImageNet | ~5 min | [`imagenet-64x64.png`](./docs/imagenet-64x64.png)
|
| 41 |
+
|
| 42 |
+
The easiest way to explore different sampling strategies is to modify [`example.py`](./example.py) directly. You can also incorporate the pre-trained models and/or our proposed EDM sampler in your own code by simply copy-pasting the relevant bits. Note that the class definitions for the pre-trained models are stored within the pickles themselves and loaded automatically during unpickling via [`torch_utils.persistence`](./torch_utils/persistence.py). To use the models in external Python scripts, just make sure that `torch_utils` and `dnnlib` are accesible through `PYTHONPATH`.
|
| 43 |
+
|
| 44 |
+
**Docker**: You can run the example script using Docker as follows:
|
| 45 |
+
|
| 46 |
+
```.bash
|
| 47 |
+
# Build the edm:latest image
|
| 48 |
+
docker build --tag edm:latest .
|
| 49 |
+
|
| 50 |
+
# Run the generate.py script using Docker:
|
| 51 |
+
docker run --gpus all -it --rm --user $(id -u):$(id -g) \
|
| 52 |
+
-v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \
|
| 53 |
+
edm:latest \
|
| 54 |
+
python example.py
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Note: The Docker image requires NVIDIA driver release `r520` or later.
|
| 58 |
+
|
| 59 |
+
The `docker run` invocation may look daunting, so let's unpack its contents here:
|
| 60 |
+
|
| 61 |
+
- `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with current user's UID/GID to avoid Docker writing files as root.
|
| 62 |
+
- ``-v `pwd`:/scratch --workdir /scratch``: mount current running dir (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working dir.
|
| 63 |
+
- `-e HOME=/scratch`: specify where to cache temporary files. Note: if you want more fine-grained control, you can instead set `DNNLIB_CACHE_DIR` (for pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations.
|
| 64 |
+
|
| 65 |
+
## Pre-trained models
|
| 66 |
+
|
| 67 |
+
We provide pre-trained models for our proposed training configuration (config F) as well as the baseline configuration (config A):
|
| 68 |
+
|
| 69 |
+
- [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/)
|
| 70 |
+
- [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/)
|
| 71 |
+
|
| 72 |
+
To generate a batch of images using a given model and sampler, run:
|
| 73 |
+
|
| 74 |
+
```.bash
|
| 75 |
+
# Generate 64 images and save them as out/*.png
|
| 76 |
+
python generate.py --outdir=out --seeds=0-63 --batch=64 \
|
| 77 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the above command using `torchrun`:
|
| 81 |
+
|
| 82 |
+
```.bash
|
| 83 |
+
# Generate 1024 images using 2 GPUs
|
| 84 |
+
torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \
|
| 85 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
The sampler settings can be controlled through command-line options; see [`python generate.py --help`](./docs/generate-help.txt) for more information. For best results, we recommend using the following settings for each dataset:
|
| 89 |
+
|
| 90 |
+
```.bash
|
| 91 |
+
# For CIFAR-10 at 32x32, use deterministic sampling with 18 steps (NFE = 35)
|
| 92 |
+
python generate.py --outdir=out --steps=18 \
|
| 93 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
|
| 94 |
+
|
| 95 |
+
# For FFHQ and AFHQv2 at 64x64, use deterministic sampling with 40 steps (NFE = 79)
|
| 96 |
+
python generate.py --outdir=out --steps=40 \
|
| 97 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl
|
| 98 |
+
|
| 99 |
+
# For ImageNet at 64x64, use stochastic sampling with 256 steps (NFE = 511)
|
| 100 |
+
python generate.py --outdir=out --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 \
|
| 101 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Besides our proposed EDM sampler, `generate.py` can also be used to reproduce the sampler ablations from Section 3 of our paper. For example:
|
| 105 |
+
|
| 106 |
+
```.bash
|
| 107 |
+
# Figure 2a, "Our reimplementation"
|
| 108 |
+
python generate.py --outdir=out --steps=512 --solver=euler --disc=vp --schedule=vp --scaling=vp \
|
| 109 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
|
| 110 |
+
|
| 111 |
+
# Figure 2a, "+ Heun & our {t_i}"
|
| 112 |
+
python generate.py --outdir=out --steps=128 --solver=heun --disc=edm --schedule=vp --scaling=vp \
|
| 113 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
|
| 114 |
+
|
| 115 |
+
# Figure 2a, "+ Our sigma(t) & s(t)"
|
| 116 |
+
python generate.py --outdir=out --steps=18 --solver=heun --disc=edm --schedule=linear --scaling=none \
|
| 117 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## Calculating FID
|
| 121 |
+
|
| 122 |
+
To compute Fréchet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`:
|
| 123 |
+
|
| 124 |
+
```.bash
|
| 125 |
+
# Generate 50000 images and save them as fid-tmp/*/*.png
|
| 126 |
+
torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \
|
| 127 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
|
| 128 |
+
|
| 129 |
+
# Calculate FID
|
| 130 |
+
torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \
|
| 131 |
+
--ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Both of the above commands can be parallelized across multiple GPUs by adjusting `--nproc_per_node`. The second command typically takes 1-3 minutes in practice, but the first one can sometimes take several hours, depending on the configuration. See [`python fid.py --help`](./docs/fid-help.txt) for the full list of options.
|
| 135 |
+
|
| 136 |
+
Note that the numerical value of FID varies across different random seeds and is highly sensitive to the number of images. By default, `fid.py` will always use 50,000 generated images; providing fewer images will result in an error, whereas providing more will use a random subset. To reduce the effect of random variation, we recommend repeating the calculation multiple times with different seeds, e.g., `--seeds=0-49999`, `--seeds=50000-99999`, and `--seeds=100000-149999`. In our paper, we calculated each FID three times and reported the minimum.
|
| 137 |
+
|
| 138 |
+
Also note that it is important to compare the generated images against the same dataset that the model was originally trained with. To facilitate evaluation, we provide the exact reference statistics that correspond to our pre-trained models:
|
| 139 |
+
|
| 140 |
+
* [https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/)
|
| 141 |
+
|
| 142 |
+
For ImageNet, we provide two sets of reference statistics to enable apples-to-apples comparison: `imagenet-64x64.npz` should be used when evaluating the EDM model (`edm-imagenet-64x64-cond-adm.pkl`), whereas `imagenet-64x64-baseline.npz` should be used when evaluating the baseline model (`baseline-imagenet-64x64-cond-adm.pkl`); the latter was originally trained by Dhariwal and Nichol using slightly different training data.
|
| 143 |
+
|
| 144 |
+
You can compute the reference statistics for your own datasets as follows:
|
| 145 |
+
|
| 146 |
+
```.bash
|
| 147 |
+
python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
## Preparing datasets
|
| 151 |
+
|
| 152 |
+
Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information.
|
| 153 |
+
|
| 154 |
+
**CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:
|
| 155 |
+
|
| 156 |
+
```.bash
|
| 157 |
+
python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
|
| 158 |
+
--dest=datasets/cifar10-32x32.zip
|
| 159 |
+
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
**FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at 64x64 resolution:
|
| 163 |
+
|
| 164 |
+
```.bash
|
| 165 |
+
python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
|
| 166 |
+
--dest=datasets/ffhq-64x64.zip --resolution=64x64
|
| 167 |
+
python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
**AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert to ZIP archive at 64x64 resolution:
|
| 171 |
+
|
| 172 |
+
```.bash
|
| 173 |
+
python dataset_tool.py --source=downloads/afhqv2 \
|
| 174 |
+
--dest=datasets/afhqv2-64x64.zip --resolution=64x64
|
| 175 |
+
python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
**ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution:
|
| 179 |
+
|
| 180 |
+
```.bash
|
| 181 |
+
python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
|
| 182 |
+
--dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
|
| 183 |
+
python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
## Training new models
|
| 187 |
+
|
| 188 |
+
You can train new models using `train.py`. For example:
|
| 189 |
+
|
| 190 |
+
```.bash
|
| 191 |
+
# Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
|
| 192 |
+
torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \
|
| 193 |
+
--data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
The above example uses the default batch size of 512 images (controlled by `--batch`) that is divided evenly among 8 GPUs (controlled by `--nproc_per_node`) to yield 64 images per GPU. Training large models may run out of GPU memory; the best way to avoid this is to limit the per-GPU batch size, e.g., `--batch-gpu=32`. This employs gradient accumulation to yield the same results as using full per-GPU batches. See [`python train.py --help`](./docs/train-help.txt) for the full list of options.
|
| 197 |
+
|
| 198 |
+
The results of each training run are saved to a newly created directory, for example `training-runs/00000-cifar10-cond-ddpmpp-edm-gpus8-batch64-fp32`. The training loop exports network snapshots (`network-snapshot-*.pkl`) and training states (`training-state-*.pt`) at regular intervals (controlled by `--snap` and `--dump`). The network snapshots can be used to generate images with `generate.py`, and the training states can be used to resume the training later on (`--resume`). Other useful information is recorded in `log.txt` and `stats.jsonl`. To monitor training convergence, we recommend looking at the training loss (`"Loss/loss"` in `stats.jsonl`) as well as periodically evaluating FID for `network-snapshot-*.pkl` using `generate.py` and `fid.py`.
|
| 199 |
+
|
| 200 |
+
The following table lists the exact training configurations that we used to obtain our pre-trained models:
|
| 201 |
+
|
| 202 |
+
| <sub>Model</sub> | <sub>GPUs</sub> | <sub>Time</sub> | <sub>Options</sub>
|
| 203 |
+
| :-- | :-- | :-- | :--
|
| 204 |
+
| <sub>cifar10‑32x32‑cond‑vp</sub> | <sub>8xV100</sub> | <sub>~2 days</sub> | <sub>`--cond=1 --arch=ddpmpp`</sub>
|
| 205 |
+
| <sub>cifar10‑32x32‑cond‑ve</sub> | <sub>8xV100</sub> | <sub>~2 days</sub> | <sub>`--cond=1 --arch=ncsnpp`</sub>
|
| 206 |
+
| <sub>cifar10‑32x32‑uncond‑vp</sub> | <sub>8xV100</sub> | <sub>~2 days</sub> | <sub>`--cond=0 --arch=ddpmpp`</sub>
|
| 207 |
+
| <sub>cifar10‑32x32‑uncond‑ve</sub> | <sub>8xV100</sub> | <sub>~2 days</sub> | <sub>`--cond=0 --arch=ncsnpp`</sub>
|
| 208 |
+
| <sub>ffhq‑64x64‑uncond‑vp</sub> | <sub>8xV100</sub> | <sub>~4 days</sub> | <sub>`--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`</sub>
|
| 209 |
+
| <sub>ffhq‑64x64‑uncond‑ve</sub> | <sub>8xV100</sub> | <sub>~4 days</sub> | <sub>`--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`</sub>
|
| 210 |
+
| <sub>afhqv2‑64x64‑uncond‑vp</sub> | <sub>8xV100</sub> | <sub>~4 days</sub> | <sub>`--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`</sub>
|
| 211 |
+
| <sub>afhqv2‑64x64‑uncond‑ve</sub> | <sub>8xV100</sub> | <sub>~4 days</sub> | <sub>`--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`</sub>
|
| 212 |
+
| <sub>imagenet‑64x64‑cond‑adm</sub> | <sub>32xA100</sub> | <sub>~13 days</sub> | <sub>`--cond=1 --arch=adm --duration=2500 --batch=4096 --lr=1e-4 --ema=50 --dropout=0.10 --augment=0 --fp16=1 --ls=100 --tick=200`</sub>
|
| 213 |
+
|
| 214 |
+
For ImageNet-64, we ran the training on four NVIDIA DGX A100 nodes, each containing 8 Ampere GPUs with 80 GB of memory. To reduce the GPU memory requirements, we recommend either training the model with more GPUs or limiting the per-GPU batch size with `--batch-gpu`. To set up multi-node training, please consult the [torchrun documentation](https://pytorch.org/docs/stable/elastic/run.html).
|
| 215 |
+
|
| 216 |
+
## License
|
| 217 |
+
|
| 218 |
+
Copyright © 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 219 |
+
|
| 220 |
+
All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).
|
| 221 |
+
|
| 222 |
+
`baseline-cifar10-32x32-uncond-vp.pkl` and `baseline-cifar10-32x32-uncond-ve.pkl` are derived from the [pre-trained models](https://github.com/yang-song/score_sde_pytorch) by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. The models were originally shared under the [Apache 2.0 license](https://github.com/yang-song/score_sde_pytorch/blob/main/LICENSE).
|
| 223 |
+
|
| 224 |
+
`baseline-imagenet-64x64-cond-adm.pkl` is derived from the [pre-trained model](https://github.com/openai/guided-diffusion) by Prafulla Dhariwal and Alex Nichol. The model was originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).
|
| 225 |
+
|
| 226 |
+
`imagenet-64x64-baseline.npz` is derived from the [precomputed reference statistics](https://github.com/openai/guided-diffusion/tree/main/evaluations) by Prafulla Dhariwal and Alex Nichol. The statistics were
|
| 227 |
+
originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).
|
| 228 |
+
|
| 229 |
+
## Citation
|
| 230 |
+
|
| 231 |
+
```
|
| 232 |
+
@inproceedings{Karras2022edm,
|
| 233 |
+
author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
|
| 234 |
+
title = {Elucidating the Design Space of Diffusion-Based Generative Models},
|
| 235 |
+
booktitle = {Proc. NeurIPS},
|
| 236 |
+
year = {2022}
|
| 237 |
+
}
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
## Development
|
| 241 |
+
|
| 242 |
+
This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
|
| 243 |
+
|
| 244 |
+
## Acknowledgments
|
| 245 |
+
|
| 246 |
+
We thank Jaakko Lehtinen, Ming-Yu Liu, Tuomas Kynkäänniemi, Axel Sauer, Arash Vahdat, and Janne Hellsten for discussions and comments, and Tero Kuosmanen, Samuel Klenberg, and Janne Hellsten for maintaining our compute infrastructure.
|
edm/dataset_tool.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Tool for creating ZIP/PNG based datasets."""

import functools
import gzip
import io
import json
import os
import pickle
import re
import sys
import tarfile
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
import click
import numpy as np
import PIL.Image
from tqdm import tqdm

#----------------------------------------------------------------------------
# Parse a 'M,N' or 'MxN' integer tuple.
# Example: '4x2' returns (4,2)

def parse_tuple(s: str) -> Tuple[int, int]:
    m = re.match(r'^(\d+)[x,](\d+)$', s)
    if m:
        return int(m.group(1)), int(m.group(2))
    raise click.ClickException(f'cannot parse tuple {s}')

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    if b is not None:
        return min(a, b)
    return a

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    return str(name).split('.')[-1]

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    ext = file_ext(fname).lower()
    return f'.{ext}' in PIL.Image.EXTENSION

#----------------------------------------------------------------------------

def open_image_folder(source_dir, *, max_images: Optional[int]):
    input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        for idx, fname in enumerate(input_images):
            img = np.array(PIL.Image.open(fname))
            yield dict(img=img, label=labels.get(arch_fnames.get(fname)))
            if idx >= max_idx - 1:
                break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_image_zip(source, *, max_images: Optional[int]):
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file))
                yield dict(img=img, label=labels.get(fname))
                if idx >= max_idx - 1:
                    break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
    import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python
    import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb

    with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
        max_idx = maybe_min(txn.stat()['entries'], max_images)

    def iterate_images():
        with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.array(PIL.Image.open(io.BytesIO(value)))
                    yield dict(img=img, label=None)
                    if idx >= max_idx - 1:
                        break
                except:
                    print(sys.exc_info()[1])

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_cifar10(tarball: str, *, max_images: Optional[int]):
    images = []
    labels = []

    with tarfile.open(tarball, 'r:gz') as tar:
        for batch in range(1, 6):
            member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
            with tar.extractfile(member) as file:
                data = pickle.load(file, encoding='latin1')
            images.append(data['data'].reshape(-1, 3, 32, 32))
            labels.append(data['labels'])

    images = np.concatenate(images)
    labels = np.concatenate(labels)
    images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
    assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            yield dict(img=img, label=int(labels[idx]))
            if idx >= max_idx - 1:
                break

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_mnist(images_gz: str, *, max_images: Optional[int]):
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    assert labels_gz != images_gz
    images = []
    labels = []

    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            yield dict(img=img, label=int(labels[idx]))
            if idx >= max_idx - 1:
                break

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img)
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2)
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2)
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    assert False, 'unknown transform'

#----------------------------------------------------------------------------

def open_dataset(source, *, max_images: Optional[int]):
    if os.path.isdir(source):
        if source.rstrip('/').endswith('_lmdb'):
            return open_lmdb(source, max_images=max_images)
        else:
            return open_image_folder(source, max_images=max_images)
    elif os.path.isfile(source):
        if os.path.basename(source) == 'cifar-10-python.tar.gz':
            return open_cifar10(source, max_images=max_images)
        elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
            return open_mnist(source, max_images=max_images)
        elif file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        else:
            assert False, 'unknown archive type'
    else:
        raise click.ClickException(f'Missing input file or directory: {source}')

#----------------------------------------------------------------------------

def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is
        # empty.
        #
        # Note: creating the output directory is not strictly
        # necessary as folder_write_bytes() also mkdirs, but it's better
        # to give an error message earlier in case the dest folder
        # somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)

def main(
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[Tuple[int, int]]
):
    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

    The input dataset format is guessed from the --source argument:

    \b
    --source *_lmdb/                    Load LSUN dataset
    --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
    --source train-images-idx3-ubyte.gz Load MNIST dataset
    --source path/                      Recursively load all images from path/
    --source dataset.zip                Recursively load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir                 Save output files under /path/to/dir
    --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives make it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the dataset
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, class labels are determined from
    top-level directory names.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale an arbitrary input image size to a specific width and height, use the
    --resolution option. Output resolution will be either the original
    input resolution (if resolution was not specified) or the one specified with
    the --resolution option.

    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
    center crop transform on the input image. These options should be used with the
    --resolution option. For example:

    \b
    python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
        --transform=center-crop-wide --resolution=512x384
    """

    PIL.Image.init()

    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)

    if resolution is None: resolution = (None, None)
    transform_image = make_transform(transform, *resolution)

    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image['img'])
        if img is None:
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        channels = img.shape[2] if img.ndim == 3 else 1
        cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
        if dataset_attrs is None:
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
            if dataset_attrs['channels'] not in [1, 3]:
                raise click.ClickException('Input images must be stored as RGB or grayscale')
            if width != 2 ** int(np.floor(np.log2(width))):
                raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, image['label']] if image['label'] is not None else None)

    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
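For reference, an archive written by the tool above can be inspected with the standard library alone. This is a minimal sketch, assuming a hypothetical my-dataset.zip produced by this script; it relies only on the layout documented in the docstring (per-image uncompressed PNGs plus a top-level dataset.json).

import json
import zipfile

import numpy as np
import PIL.Image

# Hypothetical archive, e.g. produced by: python dataset_tool.py --source ... --dest my-dataset.zip
archive = 'my-dataset.zip'

with zipfile.ZipFile(archive, 'r') as z:
    # dataset.json maps 'labels' to a list of [filename, class_index] pairs, or null if unlabeled.
    labels = json.loads(z.read('dataset.json'))['labels']
    label_map = dict(labels) if labels is not None else {}
    fname = labels[0][0] if labels else sorted(n for n in z.namelist() if n.endswith('.png'))[0]
    with z.open(fname) as f:
        img = np.array(PIL.Image.open(f))  # uncompressed PNG, shape (H, W, C) or (H, W)
    print(fname, img.shape, label_map.get(fname))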
edm/dnnlib/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

from .util import EasyDict, make_cache_dir_path
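The two names re-exported here are used throughout the other scripts in this commit. A small usage sketch; the field names are made up purely for illustration.

import dnnlib

# EasyDict behaves like a regular dict but also allows attribute-style access.
c = dnnlib.EasyDict(batch_size=128, lr=1e-3)
c.lr = 2e-4                    # same as c['lr'] = 2e-4
print(c.batch_size, c['lr'])   # -> 128 0.0002

# make_cache_dir_path() resolves a cache location (e.g. ~/.cache/dnnlib/...), without creating it.
print(dnnlib.make_cache_dir_path('downloads'))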
edm/dnnlib/util.py
ADDED
@@ -0,0 +1,491 @@
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous utility classes and functions."""
|
| 9 |
+
|
| 10 |
+
import ctypes
|
| 11 |
+
import fnmatch
|
| 12 |
+
import importlib
|
| 13 |
+
import inspect
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
import io
|
| 20 |
+
import pickle
|
| 21 |
+
import re
|
| 22 |
+
import requests
|
| 23 |
+
import html
|
| 24 |
+
import hashlib
|
| 25 |
+
import glob
|
| 26 |
+
import tempfile
|
| 27 |
+
import urllib
|
| 28 |
+
import urllib.request
|
| 29 |
+
import uuid
|
| 30 |
+
|
| 31 |
+
from distutils.util import strtobool
|
| 32 |
+
from typing import Any, List, Tuple, Union, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Util classes
|
| 36 |
+
# ------------------------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class EasyDict(dict):
|
| 40 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 41 |
+
|
| 42 |
+
def __getattr__(self, name: str) -> Any:
|
| 43 |
+
try:
|
| 44 |
+
return self[name]
|
| 45 |
+
except KeyError:
|
| 46 |
+
raise AttributeError(name)
|
| 47 |
+
|
| 48 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 49 |
+
self[name] = value
|
| 50 |
+
|
| 51 |
+
def __delattr__(self, name: str) -> None:
|
| 52 |
+
del self[name]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Logger(object):
|
| 56 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
| 57 |
+
|
| 58 |
+
def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
|
| 59 |
+
self.file = None
|
| 60 |
+
|
| 61 |
+
if file_name is not None:
|
| 62 |
+
self.file = open(file_name, file_mode)
|
| 63 |
+
|
| 64 |
+
self.should_flush = should_flush
|
| 65 |
+
self.stdout = sys.stdout
|
| 66 |
+
self.stderr = sys.stderr
|
| 67 |
+
|
| 68 |
+
sys.stdout = self
|
| 69 |
+
sys.stderr = self
|
| 70 |
+
|
| 71 |
+
def __enter__(self) -> "Logger":
|
| 72 |
+
return self
|
| 73 |
+
|
| 74 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 75 |
+
self.close()
|
| 76 |
+
|
| 77 |
+
def write(self, text: Union[str, bytes]) -> None:
|
| 78 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
| 79 |
+
if isinstance(text, bytes):
|
| 80 |
+
text = text.decode()
|
| 81 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
if self.file is not None:
|
| 85 |
+
self.file.write(text)
|
| 86 |
+
|
| 87 |
+
self.stdout.write(text)
|
| 88 |
+
|
| 89 |
+
if self.should_flush:
|
| 90 |
+
self.flush()
|
| 91 |
+
|
| 92 |
+
def flush(self) -> None:
|
| 93 |
+
"""Flush written text to both stdout and a file, if open."""
|
| 94 |
+
if self.file is not None:
|
| 95 |
+
self.file.flush()
|
| 96 |
+
|
| 97 |
+
self.stdout.flush()
|
| 98 |
+
|
| 99 |
+
def close(self) -> None:
|
| 100 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
| 101 |
+
self.flush()
|
| 102 |
+
|
| 103 |
+
# if using multiple loggers, prevent closing in wrong order
|
| 104 |
+
if sys.stdout is self:
|
| 105 |
+
sys.stdout = self.stdout
|
| 106 |
+
if sys.stderr is self:
|
| 107 |
+
sys.stderr = self.stderr
|
| 108 |
+
|
| 109 |
+
if self.file is not None:
|
| 110 |
+
self.file.close()
|
| 111 |
+
self.file = None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Cache directories
|
| 115 |
+
# ------------------------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
_dnnlib_cache_dir = None
|
| 118 |
+
|
| 119 |
+
def set_cache_dir(path: str) -> None:
|
| 120 |
+
global _dnnlib_cache_dir
|
| 121 |
+
_dnnlib_cache_dir = path
|
| 122 |
+
|
| 123 |
+
def make_cache_dir_path(*paths: str) -> str:
|
| 124 |
+
if _dnnlib_cache_dir is not None:
|
| 125 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
| 126 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 127 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
| 128 |
+
if 'HOME' in os.environ:
|
| 129 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
| 130 |
+
if 'USERPROFILE' in os.environ:
|
| 131 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
| 132 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 133 |
+
|
| 134 |
+
# Small util functions
|
| 135 |
+
# ------------------------------------------------------------------------------------------
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def format_time(seconds: Union[int, float]) -> str:
|
| 139 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 140 |
+
s = int(np.rint(seconds))
|
| 141 |
+
|
| 142 |
+
if s < 60:
|
| 143 |
+
return "{0}s".format(s)
|
| 144 |
+
elif s < 60 * 60:
|
| 145 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 146 |
+
elif s < 24 * 60 * 60:
|
| 147 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
| 148 |
+
else:
|
| 149 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
| 153 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 154 |
+
s = int(np.rint(seconds))
|
| 155 |
+
|
| 156 |
+
if s < 60:
|
| 157 |
+
return "{0}s".format(s)
|
| 158 |
+
elif s < 60 * 60:
|
| 159 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 160 |
+
elif s < 24 * 60 * 60:
|
| 161 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
| 162 |
+
else:
|
| 163 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def ask_yes_no(question: str) -> bool:
|
| 167 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
| 168 |
+
while True:
|
| 169 |
+
try:
|
| 170 |
+
print("{0} [y/n]".format(question))
|
| 171 |
+
return strtobool(input().lower())
|
| 172 |
+
except ValueError:
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def tuple_product(t: Tuple) -> Any:
|
| 177 |
+
"""Calculate the product of the tuple elements."""
|
| 178 |
+
result = 1
|
| 179 |
+
|
| 180 |
+
for v in t:
|
| 181 |
+
result *= v
|
| 182 |
+
|
| 183 |
+
return result
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
_str_to_ctype = {
|
| 187 |
+
"uint8": ctypes.c_ubyte,
|
| 188 |
+
"uint16": ctypes.c_uint16,
|
| 189 |
+
"uint32": ctypes.c_uint32,
|
| 190 |
+
"uint64": ctypes.c_uint64,
|
| 191 |
+
"int8": ctypes.c_byte,
|
| 192 |
+
"int16": ctypes.c_int16,
|
| 193 |
+
"int32": ctypes.c_int32,
|
| 194 |
+
"int64": ctypes.c_int64,
|
| 195 |
+
"float32": ctypes.c_float,
|
| 196 |
+
"float64": ctypes.c_double
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
| 201 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
| 202 |
+
type_str = None
|
| 203 |
+
|
| 204 |
+
if isinstance(type_obj, str):
|
| 205 |
+
type_str = type_obj
|
| 206 |
+
elif hasattr(type_obj, "__name__"):
|
| 207 |
+
type_str = type_obj.__name__
|
| 208 |
+
elif hasattr(type_obj, "name"):
|
| 209 |
+
type_str = type_obj.name
|
| 210 |
+
else:
|
| 211 |
+
raise RuntimeError("Cannot infer type name from input")
|
| 212 |
+
|
| 213 |
+
assert type_str in _str_to_ctype.keys()
|
| 214 |
+
|
| 215 |
+
my_dtype = np.dtype(type_str)
|
| 216 |
+
my_ctype = _str_to_ctype[type_str]
|
| 217 |
+
|
| 218 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
| 219 |
+
|
| 220 |
+
return my_dtype, my_ctype
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def is_pickleable(obj: Any) -> bool:
|
| 224 |
+
try:
|
| 225 |
+
with io.BytesIO() as stream:
|
| 226 |
+
pickle.dump(obj, stream)
|
| 227 |
+
return True
|
| 228 |
+
except:
|
| 229 |
+
return False
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 233 |
+
# ------------------------------------------------------------------------------------------
|
| 234 |
+
|
| 235 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
| 236 |
+
"""Searches for the underlying module behind the name to some python object.
|
| 237 |
+
Returns the module and the object name (original name with module part removed)."""
|
| 238 |
+
|
| 239 |
+
# allow convenience shorthands, substitute them by full names
|
| 240 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
| 241 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
| 242 |
+
|
| 243 |
+
# list alternatives for (module_name, local_obj_name)
|
| 244 |
+
parts = obj_name.split(".")
|
| 245 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
| 246 |
+
|
| 247 |
+
# try each alternative in turn
|
| 248 |
+
for module_name, local_obj_name in name_pairs:
|
| 249 |
+
try:
|
| 250 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 251 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 252 |
+
return module, local_obj_name
|
| 253 |
+
except:
|
| 254 |
+
pass
|
| 255 |
+
|
| 256 |
+
# maybe some of the modules themselves contain errors?
|
| 257 |
+
for module_name, _local_obj_name in name_pairs:
|
| 258 |
+
try:
|
| 259 |
+
importlib.import_module(module_name) # may raise ImportError
|
| 260 |
+
except ImportError:
|
| 261 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
| 262 |
+
raise
|
| 263 |
+
|
| 264 |
+
# maybe the requested attribute is missing?
|
| 265 |
+
for module_name, local_obj_name in name_pairs:
|
| 266 |
+
try:
|
| 267 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 268 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 269 |
+
except ImportError:
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
# we are out of luck, but we have no idea why
|
| 273 |
+
raise ImportError(obj_name)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
| 277 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
| 278 |
+
if obj_name == '':
|
| 279 |
+
return module
|
| 280 |
+
obj = module
|
| 281 |
+
for part in obj_name.split("."):
|
| 282 |
+
obj = getattr(obj, part)
|
| 283 |
+
return obj
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def get_obj_by_name(name: str) -> Any:
|
| 287 |
+
"""Finds the python object with the given name."""
|
| 288 |
+
module, obj_name = get_module_from_obj_name(name)
|
| 289 |
+
return get_obj_from_module(module, obj_name)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
| 293 |
+
"""Finds the python object with the given name and calls it as a function."""
|
| 294 |
+
assert func_name is not None
|
| 295 |
+
func_obj = get_obj_by_name(func_name)
|
| 296 |
+
assert callable(func_obj)
|
| 297 |
+
return func_obj(*args, **kwargs)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
| 301 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
| 302 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
| 306 |
+
"""Get the directory path of the module containing the given object name."""
|
| 307 |
+
module, _ = get_module_from_obj_name(obj_name)
|
| 308 |
+
return os.path.dirname(inspect.getfile(module))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def is_top_level_function(obj: Any) -> bool:
|
| 312 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
| 313 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def get_top_level_function_name(obj: Any) -> str:
|
| 317 |
+
"""Return the fully-qualified name of a top-level function."""
|
| 318 |
+
assert is_top_level_function(obj)
|
| 319 |
+
module = obj.__module__
|
| 320 |
+
if module == '__main__':
|
| 321 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
| 322 |
+
return module + "." + obj.__name__
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# File system helpers
|
| 326 |
+
# ------------------------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
| 329 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
| 330 |
+
Returns list of tuples containing both absolute and relative paths."""
|
| 331 |
+
assert os.path.isdir(dir_path)
|
| 332 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
| 333 |
+
|
| 334 |
+
if ignores is None:
|
| 335 |
+
ignores = []
|
| 336 |
+
|
| 337 |
+
result = []
|
| 338 |
+
|
| 339 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
| 340 |
+
for ignore_ in ignores:
|
| 341 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
| 342 |
+
|
| 343 |
+
# dirs need to be edited in-place
|
| 344 |
+
for d in dirs_to_remove:
|
| 345 |
+
dirs.remove(d)
|
| 346 |
+
|
| 347 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
| 348 |
+
|
| 349 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
| 350 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
| 351 |
+
|
| 352 |
+
if add_base_to_relative:
|
| 353 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
| 354 |
+
|
| 355 |
+
assert len(absolute_paths) == len(relative_paths)
|
| 356 |
+
result += zip(absolute_paths, relative_paths)
|
| 357 |
+
|
| 358 |
+
return result
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
| 362 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
| 363 |
+
Will create all necessary directories."""
|
| 364 |
+
for file in files:
|
| 365 |
+
target_dir_name = os.path.dirname(file[1])
|
| 366 |
+
|
| 367 |
+
# will create all intermediate-level directories
|
| 368 |
+
if not os.path.exists(target_dir_name):
|
| 369 |
+
os.makedirs(target_dir_name)
|
| 370 |
+
|
| 371 |
+
shutil.copyfile(file[0], file[1])
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# URL helpers
|
| 375 |
+
# ------------------------------------------------------------------------------------------
|
| 376 |
+
|
| 377 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
| 378 |
+
"""Determine whether the given object is a valid URL string."""
|
| 379 |
+
if not isinstance(obj, str) or not "://" in obj:
|
| 380 |
+
return False
|
| 381 |
+
if allow_file_urls and obj.startswith('file://'):
|
| 382 |
+
return True
|
| 383 |
+
try:
|
| 384 |
+
res = requests.compat.urlparse(obj)
|
| 385 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 386 |
+
return False
|
| 387 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
| 388 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 389 |
+
return False
|
| 390 |
+
except:
|
| 391 |
+
return False
|
| 392 |
+
return True
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
| 396 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 397 |
+
assert num_attempts >= 1
|
| 398 |
+
assert not (return_filename and (not cache))
|
| 399 |
+
|
| 400 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 401 |
+
if not re.match('^[a-z]+://', url):
|
| 402 |
+
return url if return_filename else open(url, "rb")
|
| 403 |
+
|
| 404 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 405 |
+
# arise on Windows:
|
| 406 |
+
#
|
| 407 |
+
# file:///c:/foo.txt
|
| 408 |
+
#
|
| 409 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 410 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 411 |
+
#
|
| 412 |
+
# If you touch this code path, you should test it on both Linux and
|
| 413 |
+
# Windows.
|
| 414 |
+
#
|
| 415 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 416 |
+
# but that converts forward slashes to backslashes and this causes
|
| 417 |
+
# its own set of problems.
|
| 418 |
+
if url.startswith('file://'):
|
| 419 |
+
filename = urllib.parse.urlparse(url).path
|
| 420 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
| 421 |
+
filename = filename[1:]
|
| 422 |
+
return filename if return_filename else open(filename, "rb")
|
| 423 |
+
|
| 424 |
+
assert is_url(url)
|
| 425 |
+
|
| 426 |
+
# Lookup from cache.
|
| 427 |
+
if cache_dir is None:
|
| 428 |
+
cache_dir = make_cache_dir_path('downloads')
|
| 429 |
+
|
| 430 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 431 |
+
if cache:
|
| 432 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
| 433 |
+
if len(cache_files) == 1:
|
| 434 |
+
filename = cache_files[0]
|
| 435 |
+
return filename if return_filename else open(filename, "rb")
|
| 436 |
+
|
| 437 |
+
# Download.
|
| 438 |
+
url_name = None
|
| 439 |
+
url_data = None
|
| 440 |
+
with requests.Session() as session:
|
| 441 |
+
if verbose:
|
| 442 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 443 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 444 |
+
try:
|
| 445 |
+
with session.get(url) as res:
|
| 446 |
+
res.raise_for_status()
|
| 447 |
+
if len(res.content) == 0:
|
| 448 |
+
raise IOError("No data received")
|
| 449 |
+
|
| 450 |
+
if len(res.content) < 8192:
|
| 451 |
+
content_str = res.content.decode("utf-8")
|
| 452 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 453 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
| 454 |
+
if len(links) == 1:
|
| 455 |
+
url = requests.compat.urljoin(url, links[0])
|
| 456 |
+
raise IOError("Google Drive virus checker nag")
|
| 457 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 458 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
| 459 |
+
|
| 460 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
| 461 |
+
url_name = match[1] if match else url
|
| 462 |
+
url_data = res.content
|
| 463 |
+
if verbose:
|
| 464 |
+
print(" done")
|
| 465 |
+
break
|
| 466 |
+
except KeyboardInterrupt:
|
| 467 |
+
raise
|
| 468 |
+
except:
|
| 469 |
+
if not attempts_left:
|
| 470 |
+
if verbose:
|
| 471 |
+
print(" failed")
|
| 472 |
+
raise
|
| 473 |
+
if verbose:
|
| 474 |
+
print(".", end="", flush=True)
|
| 475 |
+
|
| 476 |
+
# Save to cache.
|
| 477 |
+
if cache:
|
| 478 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
| 479 |
+
safe_name = safe_name[:min(len(safe_name), 128)]
|
| 480 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
| 481 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
| 482 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 483 |
+
with open(temp_file, "wb") as f:
|
| 484 |
+
f.write(url_data)
|
| 485 |
+
os.replace(temp_file, cache_file) # atomic
|
| 486 |
+
if return_filename:
|
| 487 |
+
return cache_file
|
| 488 |
+
|
| 489 |
+
# Return data as file object.
|
| 490 |
+
assert not return_filename
|
| 491 |
+
return io.BytesIO(url_data)
|
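dnnlib.util.open_url() defined above is how the other scripts in this commit fetch pretrained pickles with local caching. A hedged sketch, reusing the CIFAR-10 checkpoint URL that appears elsewhere in this commit; unpickling the network additionally requires torch and this repository's torch_utils package on the import path.

import pickle

import dnnlib

# Any of the EDM pretrained pickles referenced in example.py would work the same way.
url = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'

# open_url() downloads once, stores the file under the dnnlib cache dir,
# and returns a binary file object (or a filename when return_filename=True).
with dnnlib.util.open_url(url, cache=True, verbose=True) as f:
    net = pickle.load(f)['ema']
print(net.img_resolution, net.label_dim)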
edm/environment.yml
ADDED
@@ -0,0 +1,19 @@
name: edm
channels:
  - pytorch
  - nvidia
dependencies:
  - python>=3.8, < 3.10  # package build failures on 3.10
  - pip
  - numpy>=1.20
  - click>=8.0
  - pillow>=8.3.1
  - scipy>=1.7.1
  - pytorch=1.12.1
  - psutil
  - requests
  - tqdm
  - imageio
  - pip:
    - imageio-ffmpeg>=0.4.3
    - pyspng
edm/example.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Minimal standalone example to reproduce the main results from the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib

#----------------------------------------------------------------------------

def generate_image_grid(
    network_pkl, dest_path,
    seed=0, gridw=8, gridh=8, device=torch.device('cuda'),
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    batch_size = gridw * gridh
    torch.manual_seed(seed)

    # Load network.
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        net = pickle.load(f)['ema'].to(device)

    # Pick latents and labels.
    print(f'Generating {batch_size} images...')
    latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
    class_labels = None
    if net.label_dim:
        class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    # Save image grid.
    print(f'Saving image grid to "{dest_path}"...')
    image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels)
    image = image.cpu().numpy()
    PIL.Image.fromarray(image, 'RGB').save(dest_path)
    print('Done.')

#----------------------------------------------------------------------------

def main():
    model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
    generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35
    generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79
    generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79
    generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
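The rho-spaced noise-level schedule used by the sampler above can be inspected on its own. A minimal NumPy sketch of the same discretization, using the default parameters from generate_image_grid():

import numpy as np

num_steps, sigma_min, sigma_max, rho = 18, 0.002, 80.0, 7
i = np.arange(num_steps)
t = (sigma_max ** (1 / rho) + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
print(t[:3], t[-3:])  # decreases from sigma_max=80 down to sigma_min=0.002; t_N = 0 is appended separately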
edm/fid.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Script for calculating Frechet Inception Distance (FID)."""

import os
import click
import tqdm
import pickle
import numpy as np
import scipy.linalg
import torch
import dnnlib
from torch_utils import distributed as dist
from training import dataset

#----------------------------------------------------------------------------

def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    dist.print0('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        detector_net = pickle.load(f).to(device)

    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
    if len(dataset_obj) < 2:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Divide images into batches.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    # Accumulate statistics.
    dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals.
    torch.distributed.all_reduce(mu)
    torch.distributed.all_reduce(sigma)
    mu /= len(dataset_obj)
    sigma -= mu.ger(mu) * len(dataset_obj)
    sigma /= len(dataset_obj) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()

#----------------------------------------------------------------------------

def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    m = np.square(mu - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    fid = m + np.trace(sigma + sigma_ref - s * 2)
    return float(np.real(fid))

#----------------------------------------------------------------------------

@click.group()
def main():
    """Calculate Frechet Inception Distance (FID).

    Examples:

    \b
    # Generate 50000 images and save them as fid-tmp/*/*.png
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Calculate FID
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz

    \b
    # Compute dataset reference statistics
    python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
    """

#----------------------------------------------------------------------------

@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    ref = None
    if dist.get_rank() == 0:
        with dnnlib.util.open_url(ref_path) as f:
            ref = dict(np.load(f))

    mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
    dist.print0('Calculating FID...')
    if dist.get_rank() == 0:
        fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
        print(f'{fid:g}')
    torch.distributed.barrier()

#----------------------------------------------------------------------------

@main.command()
@click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
@click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics needed by 'calc'."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
    dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
    if dist.get_rank() == 0:
        if os.path.dirname(dest_path):
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        np.savez(dest_path, mu=mu, sigma=sigma)

    torch.distributed.barrier()
    dist.print0('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
edm/generate.py
ADDED
|
@@ -0,0 +1,316 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Generate random images using the techniques described in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import os
import re
import click
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
from torch_utils import distributed as dist

#----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).

def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next

#----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.

def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next

#----------------------------------------------------------------------------
# Wrapper for torch.Generator that allows specifying a different random seed
# for each sample in a minibatch.

class StackedRandomGenerator:
    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])

#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
@click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
@click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
@click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
@click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

@click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
@click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
@click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)

@click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
@click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
@click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
@click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))

def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
    """Generate random images using the techniques described in the paper
    "Elucidating the Design Space of Diffusion-Based Generative Models".

    Examples:

    \b
    # Generate 64 images and save them as out/*.png
    python generate.py --outdir=out --seeds=0-63 --batch=64 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Generate 1024 images using 2 GPUs
    torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
    """
    dist.init()
    num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load network.
    dist.print0(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
        net = pickle.load(f)['ema'].to(device)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Loop over batches.
    dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
    for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
        torch.distributed.barrier()
        batch_size = len(batch_seeds)
        if batch_size == 0:
            continue

        # Pick latents and labels.
        rnd = StackedRandomGenerator(device, batch_seeds)
        latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # Generate images.
        sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
        have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
        sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
        images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)

        # Save images.
        images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        for seed, image_np in zip(batch_seeds, images_np):
            image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f'{seed:06d}.png')
            if image_np.shape[2] == 1:
                PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
            else:
                PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    # Done.
    torch.distributed.barrier()
    dist.print0('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
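As a complement to the CLI in generate.py, here is a minimal single-process sketch of calling edm_sampler directly from Python. It is illustrative only: it assumes a CUDA device, that it is run from inside the edm/ directory so that generate and dnnlib are importable, and that the pretrained pickle URL from the docstring above is reachable.

# Sketch: sample a few images with edm_sampler (assumptions noted above).
import pickle
import torch
import dnnlib
from generate import edm_sampler, StackedRandomGenerator

network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    net = pickle.load(f)['ema'].to(device)

seeds = [0, 1, 2, 3]
rnd = StackedRandomGenerator(device, seeds)   # one torch.Generator per seed, as in main()
latents = rnd.randn([len(seeds), net.img_channels, net.img_resolution, net.img_resolution], device=device)
class_labels = None
if net.label_dim:
    class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[len(seeds)], device=device)]
images = edm_sampler(net, latents, class_labels, randn_like=rnd.randn_like, num_steps=18)
print(images.shape)  # [4, C, H, W]; convert to uint8 via images * 127.5 + 128 as in main()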
edm/torch_utils/__init__.py ADDED
@@ -0,0 +1,8 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

# empty
edm/torch_utils/distributed.py ADDED
@@ -0,0 +1,59 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import os
import torch
from . import training_stats

#----------------------------------------------------------------------------

def init():
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = 'localhost'
    if 'MASTER_PORT' not in os.environ:
        os.environ['MASTER_PORT'] = '29500'
    if 'RANK' not in os.environ:
        os.environ['RANK'] = '0'
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = '0'
    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = '1'

    backend = 'gloo' if os.name == 'nt' else 'nccl'
    torch.distributed.init_process_group(backend=backend, init_method='env://')
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)

#----------------------------------------------------------------------------

def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

#----------------------------------------------------------------------------

def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

#----------------------------------------------------------------------------

def should_stop():
    return False

#----------------------------------------------------------------------------

def update_progress(cur, total):
    _ = cur, total

#----------------------------------------------------------------------------

def print0(*args, **kwargs):
    if get_rank() == 0:
        print(*args, **kwargs)

#----------------------------------------------------------------------------
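A small usage sketch of the helpers above: when a script is launched as plain `python script.py` (no torchrun), init() falls back to a one-process group using the default MASTER_ADDR/MASTER_PORT values it sets itself. The sketch assumes a CUDA device is available, since init() calls torch.cuda.set_device().

# Sketch: single-process use of the distributed helpers; torchrun supplies
# RANK/WORLD_SIZE/LOCAL_RANK automatically in the multi-GPU case.
from torch_utils import distributed as dist

dist.init()
dist.print0(f'rank {dist.get_rank()} of {dist.get_world_size()} process(es)')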
edm/torch_utils/misc.py ADDED
@@ -0,0 +1,266 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import re
import contextlib
import numpy as np
import torch
import warnings
import edm.dnnlib as dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name])

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------
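To illustrate InfiniteSampler, a short sketch of pairing it with a DataLoader so that iteration never ends and each rank would see its own slice of the shuffled order. The toy dataset and batch size are made up for the example.

# Sketch: endless, shuffled iteration with InfiniteSampler (toy dataset).
import torch
from torch_utils import misc

dataset = torch.utils.data.TensorDataset(torch.randn(1000, 3, 32, 32))
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=16))
images, = next(loader)   # keep calling next(); the sampler never raises StopIteration
print(images.shape)      # torch.Size([16, 3, 32, 32])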
edm/torch_utils/persistence.py ADDED
@@ -0,0 +1,257 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import edm.dnnlib as dnnlib

#----------------------------------------------------------------------------

_version = 6 # internal version number
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    Decorator.__module__ = orig_class.__module__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type: Type of the persistent object, e.g. `'class'`.
        version: Internal version number of `torch_utils.persistence`.
        module_src Original source code of the Python module.
        class_name: Class name in the original Python module.
        state: Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------
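A brief sketch of the round trip the docstring above describes. The class here is hypothetical and must live in a regular .py file (not an interactive session), because persistent_class stores the defining module's source via inspect.getsource().

# Sketch: pickling a persistent class keeps its source inside the pickle.
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):          # hypothetical example class
    def __init__(self, width=8):
        super().__init__()
        self.fc = torch.nn.Linear(width, width)

net = TinyNet(width=16)
blob = pickle.dumps(net)                 # embeds the module source + init args
restored = pickle.loads(blob)            # loadable even if TinyNet is later changed
print(restored.init_kwargs)              # {'width': 16}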
edm/torch_utils/training_stats.py ADDED
@@ -0,0 +1,272 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import edm.dnnlib as dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype = torch.float64 # Data type to use for the internal counters.
_rank = 0 # Rank of the current process.
_sync_device = None # Device to use for multiprocess communication. None = single-process.
_sync_called = False # Has _sync() been called yet?
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank: Rank of the current process.
        sync_device: PyTorch device to use for inter-process
            communication, or None to disable multi-process
            collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name: Arbitrary string specifying the name of the statistic.
            Averages are accumulated separately for each unique name.
        value: Arbitrary set of scalars. Can be a list, tuple,
            NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),
        elems.sum(),
        elems.square().sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    report(name, value if _rank == 0 else [])
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex: Regular expression defining which statistics to
            collect. The default is to collect everything.
        keep_previous: Whether to retain the previous averages if no
            scalars were collected on a given round
            (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
+
if name not in self._cumulative:
|
| 163 |
+
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
| 164 |
+
delta = cumulative - self._cumulative[name]
|
| 165 |
+
self._cumulative[name].copy_(cumulative)
|
| 166 |
+
if float(delta[0]) != 0:
|
| 167 |
+
self._moments[name] = delta
|
| 168 |
+
|
| 169 |
+
def _get_delta(self, name):
|
| 170 |
+
r"""Returns the raw moments that were accumulated for the given
|
| 171 |
+
statistic between the last two calls to `update()`, or zero if
|
| 172 |
+
no scalars were collected.
|
| 173 |
+
"""
|
| 174 |
+
assert self._regex.fullmatch(name)
|
| 175 |
+
if name not in self._moments:
|
| 176 |
+
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
| 177 |
+
return self._moments[name]
|
| 178 |
+
|
| 179 |
+
def num(self, name):
|
| 180 |
+
r"""Returns the number of scalars that were accumulated for the given
|
| 181 |
+
statistic between the last two calls to `update()`, or zero if
|
| 182 |
+
no scalars were collected.
|
| 183 |
+
"""
|
| 184 |
+
delta = self._get_delta(name)
|
| 185 |
+
return int(delta[0])
|
| 186 |
+
|
| 187 |
+
def mean(self, name):
|
| 188 |
+
r"""Returns the mean of the scalars that were accumulated for the
|
| 189 |
+
given statistic between the last two calls to `update()`, or NaN if
|
| 190 |
+
no scalars were collected.
|
| 191 |
+
"""
|
| 192 |
+
delta = self._get_delta(name)
|
| 193 |
+
if int(delta[0]) == 0:
|
| 194 |
+
return float('nan')
|
| 195 |
+
return float(delta[1] / delta[0])
|
| 196 |
+
|
| 197 |
+
def std(self, name):
|
| 198 |
+
r"""Returns the standard deviation of the scalars that were
|
| 199 |
+
accumulated for the given statistic between the last two calls to
|
| 200 |
+
`update()`, or NaN if no scalars were collected.
|
| 201 |
+
"""
|
| 202 |
+
delta = self._get_delta(name)
|
| 203 |
+
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
| 204 |
+
return float('nan')
|
| 205 |
+
if int(delta[0]) == 1:
|
| 206 |
+
return float(0)
|
| 207 |
+
mean = float(delta[1] / delta[0])
|
| 208 |
+
raw_var = float(delta[2] / delta[0])
|
| 209 |
+
return np.sqrt(max(raw_var - np.square(mean), 0))
|
| 210 |
+
|
| 211 |
+
def as_dict(self):
|
| 212 |
+
r"""Returns the averages accumulated between the last two calls to
|
| 213 |
+
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
| 214 |
+
|
| 215 |
+
dnnlib.EasyDict(
|
| 216 |
+
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
| 217 |
+
...
|
| 218 |
+
)
|
| 219 |
+
"""
|
| 220 |
+
stats = dnnlib.EasyDict()
|
| 221 |
+
for name in self.names():
|
| 222 |
+
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
| 223 |
+
return stats
|
| 224 |
+
|
| 225 |
+
def __getitem__(self, name):
|
| 226 |
+
r"""Convenience getter.
|
| 227 |
+
`collector[name]` is a synonym for `collector.mean(name)`.
|
| 228 |
+
"""
|
| 229 |
+
return self.mean(name)
|
| 230 |
+
|
| 231 |
+
#----------------------------------------------------------------------------
|
| 232 |
+
|
| 233 |
+
def _sync(names):
|
| 234 |
+
r"""Synchronize the global cumulative counters across devices and
|
| 235 |
+
processes. Called internally by `Collector.update()`.
|
| 236 |
+
"""
|
| 237 |
+
if len(names) == 0:
|
| 238 |
+
return []
|
| 239 |
+
global _sync_called
|
| 240 |
+
_sync_called = True
|
| 241 |
+
|
| 242 |
+
# Collect deltas within current rank.
|
| 243 |
+
deltas = []
|
| 244 |
+
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
| 245 |
+
for name in names:
|
| 246 |
+
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
| 247 |
+
for counter in _counters[name].values():
|
| 248 |
+
delta.add_(counter.to(device))
|
| 249 |
+
counter.copy_(torch.zeros_like(counter))
|
| 250 |
+
deltas.append(delta)
|
| 251 |
+
deltas = torch.stack(deltas)
|
| 252 |
+
|
| 253 |
+
# Sum deltas across ranks.
|
| 254 |
+
if _sync_device is not None:
|
| 255 |
+
torch.distributed.all_reduce(deltas)
|
| 256 |
+
|
| 257 |
+
# Update cumulative values.
|
| 258 |
+
deltas = deltas.cpu()
|
| 259 |
+
for idx, name in enumerate(names):
|
| 260 |
+
if name not in _cumulative:
|
| 261 |
+
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
| 262 |
+
_cumulative[name].add_(deltas[idx])
|
| 263 |
+
|
| 264 |
+
# Return name-value pairs.
|
| 265 |
+
return [(name, _cumulative[name]) for name in names]
|
| 266 |
+
|
| 267 |
+
#----------------------------------------------------------------------------
|
| 268 |
+
# Convenience.
|
| 269 |
+
|
| 270 |
+
default_collector = Collector()
|
| 271 |
+
|
| 272 |
+
#----------------------------------------------------------------------------
|
edm/train.py
ADDED
@@ -0,0 +1,236 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Train diffusion-based generative model using the techniques described in the
paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""

import os
import re
import json
import click
import torch
import dnnlib
from torch_utils import distributed as dist
from training import training_loop

import warnings
warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.

#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

@click.command()

# Main options.
@click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
@click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
@click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)

# Hyperparameters.
@click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
@click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
@click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
@click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
@click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
@click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
@click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
@click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)

# Performance-related.
@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
@click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)

# I/O-related.
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
@click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
@click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
@click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
@click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

def main(**kwargs):
    """Train diffusion-based generative model using the techniques described in the
    paper "Elucidating the Design Space of Diffusion-Based Generative Models".

    Examples:

    \b
    # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
    torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\
        --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
    """
    opts = dnnlib.EasyDict(kwargs)
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Initialize config dict.
    c = dnnlib.EasyDict()
    c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
    c.network_kwargs = dnnlib.EasyDict()
    c.loss_kwargs = dnnlib.EasyDict()
    c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)

    # Validate dataset options.
    try:
        dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
        dataset_name = dataset_obj.name
        c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
        c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
        if opts.cond and not dataset_obj.has_labels:
            raise click.ClickException('--cond=True requires labels specified in dataset.json')
        del dataset_obj # conserve memory
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

    # Network architecture.
    if opts.arch == 'ddpmpp':
        c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
    elif opts.arch == 'ncsnpp':
        c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
    else:
        assert opts.arch == 'adm'
        c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])

    # Preconditioning & loss function.
    if opts.precond == 'vp':
        c.network_kwargs.class_name = 'training.networks.VPPrecond'
        c.loss_kwargs.class_name = 'training.loss.VPLoss'
    elif opts.precond == 've':
        c.network_kwargs.class_name = 'training.networks.VEPrecond'
        c.loss_kwargs.class_name = 'training.loss.VELoss'
    else:
        assert opts.precond == 'edm'
        c.network_kwargs.class_name = 'training.networks.EDMPrecond'
        c.loss_kwargs.class_name = 'training.loss.EDMLoss'

    # Network options.
    if opts.cbase is not None:
        c.network_kwargs.model_channels = opts.cbase
    if opts.cres is not None:
        c.network_kwargs.channel_mult = opts.cres
    if opts.augment:
        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
        c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
        c.network_kwargs.augment_dim = 9
    c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)

    # Training options.
    c.total_kimg = max(int(opts.duration * 1000), 1)
    c.ema_halflife_kimg = int(opts.ema * 1000)
    c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
    c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
    c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)

    # Random seed.
    if opts.seed is not None:
        c.seed = opts.seed
    else:
        seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
        torch.distributed.broadcast(seed, src=0)
        c.seed = int(seed)

    # Transfer learning and resume.
    if opts.transfer is not None:
        if opts.resume is not None:
            raise click.ClickException('--transfer and --resume cannot be specified at the same time')
        c.resume_pkl = opts.transfer
        c.ema_rampup_ratio = None
    elif opts.resume is not None:
        match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
        if not match or not os.path.isfile(opts.resume):
            raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
        c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
        c.resume_kimg = int(match.group(1))
        c.resume_state_dump = opts.resume

    # Description string.
    cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
    dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
    desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Pick output directory.
    if dist.get_rank() != 0:
        c.run_dir = None
    elif opts.nosubdir:
        c.run_dir = opts.outdir
    else:
        prev_run_dirs = []
        if os.path.isdir(opts.outdir):
            prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)

    # Print options.
    dist.print0()
    dist.print0('Training options:')
    dist.print0(json.dumps(c, indent=2))
    dist.print0()
    dist.print0(f'Output directory: {c.run_dir}')
    dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
    dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
    dist.print0(f'Network architecture: {opts.arch}')
    dist.print0(f'Preconditioning & loss: {opts.precond}')
    dist.print0(f'Number of GPUs: {dist.get_world_size()}')
    dist.print0(f'Batch size: {c.batch_size}')
    dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
    dist.print0()

    # Dry run?
    if opts.dry_run:
        dist.print0('Dry run; exiting.')
        return

    # Create output directory.
    dist.print0('Creating output directory...')
    if dist.get_rank() == 0:
        os.makedirs(c.run_dir, exist_ok=True)
        with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
            json.dump(c, f, indent=2)
        dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Train.
    training_loop.training_loop(**c)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
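Quick illustration (not part of the committed file): `parse_int_list` above expands `--cres`-style range strings, and `--dry-run` can be used to sanity-check the assembled config; the import path and dataset path below are placeholders.

# Expected behavior of parse_int_list, per the comment above:
from train import parse_int_list                 # assumes edm/ is the working directory
assert parse_int_list('1,2,5-10') == [1, 2, 5, 6, 7, 8, 9, 10]
assert parse_int_list([2, 2, 2]) == [2, 2, 2]    # lists pass through unchanged

# Hypothetical single-GPU dry run using the options defined above:
# torchrun --standalone --nproc_per_node=1 train.py --outdir=training-runs \
#     --data=datasets/cifar10-32x32.zip --arch=ddpmpp --precond=edm --dry-run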
edm/training/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

# empty
edm/training/augment.py
ADDED
@@ -0,0 +1,330 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Augmentation pipeline used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models".
Built around the same concepts that were originally proposed in the paper
"Training Generative Adversarial Networks with Limited Data"."""

import numpy as np
import torch
from torch_utils import persistence
from torch_utils import misc

#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.

wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1':  [0.7071067811865476, 0.7071067811865476],
    'db2':  [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3':  [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4':  [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5':  [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6':  [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7':  [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8':  [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

#----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.

def matrix(*rows, device=None):
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    assert device is None or device == ref[0].device
    elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))

def translate2d(tx, ty, **kwargs):
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)

def translate3d(tx, ty, tz, **kwargs):
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)

def scale2d(sx, sy, **kwargs):
    return matrix(
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
        **kwargs)

def scale3d(sx, sy, sz, **kwargs):
    return matrix(
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
        **kwargs)

def rotate2d(theta, **kwargs):
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta), 0],
        [0, 0, 1],
        **kwargs)

def rotate3d(v, theta, **kwargs):
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
        [0, 0, 0, 1],
        **kwargs)

def translate2d_inv(tx, ty, **kwargs):
    return translate2d(-tx, -ty, **kwargs)

def scale2d_inv(sx, sy, **kwargs):
    return scale2d(1 / sx, 1 / sy, **kwargs)

def rotate2d_inv(theta, **kwargs):
    return rotate2d(-theta, **kwargs)

#----------------------------------------------------------------------------
# Augmentation pipeline main class.
# All augmentations are disabled by default; individual augmentations can
# be enabled by setting their probability multipliers to 1.

@persistence.persistent_class
class AugmentPipe:
    def __init__(self, p=1,
        xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125,
        scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
    ):
        super().__init__()
        self.p = float(p) # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip = float(xflip) # Probability multiplier for x-flip.
        self.yflip = float(yflip) # Probability multiplier for y-flip.
        self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation.
        self.translate_int = float(translate_int) # Probability multiplier for integer translation.
        self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions.

        # Geometric transformations.
        self.scale = float(scale) # Probability multiplier for isotropic scaling.
        self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation.
        self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
        self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
        self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
        self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
        self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness) # Probability multiplier for brightness.
        self.contrast = float(contrast) # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
        self.hue = float(hue) # Probability multiplier for hue rotation.
        self.saturation = float(saturation) # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std) # Standard deviation of brightness.
        self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.

    def __call__(self, images):
        N, C, H, W = images.shape
        device = images.device
        labels = [torch.zeros([images.shape[0], 0], device=device)]

        # ---------------
        # Pixel blitting.
        # ---------------

        if self.xflip > 0:
            w = torch.randint(2, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
            images = torch.where(w == 1, images.flip(3), images)
            labels += [w]

        if self.yflip > 0:
            w = torch.randint(2, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
            images = torch.where(w == 1, images.flip(2), images)
            labels += [w]

        if self.rotate_int > 0:
            w = torch.randint(4, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
            images = torch.where((w == 1) | (w == 2), images.flip(3), images)
            images = torch.where((w == 2) | (w == 3), images.flip(2), images)
            images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
            labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]

        if self.translate_int > 0:
            w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
            w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
            tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
            ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
            b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
            x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
            y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
            images = images.flatten()[(((b * C) + c) * H + y) * W + x]
            labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]

        # ------------------------------------------------
        # Select parameters for geometric transformations.
        # ------------------------------------------------

        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        if self.scale > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w))
            s = w.mul(self.scale_std).exp2()
            G_inv = G_inv @ scale2d_inv(s, s)
            labels += [w]

        if self.rotate_frac > 0:
            w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max)
            w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w))
            G_inv = G_inv @ rotate2d_inv(-w)
            labels += [w.cos() - 1, w.sin()]

        if self.aniso > 0:
            w = torch.randn([N], device=device)
            r = (torch.rand([N], device=device) * 2 - 1) * np.pi
            w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w))
            r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r))
            s = w.mul(self.aniso_std).exp2()
            G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r)
            labels += [w * r.cos(), w * r.sin()]

        if self.translate_frac > 0:
            w = torch.randn([2, N], device=device)
            w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w))
            G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std))
            labels += [w[0], w[1]]

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        if G_inv is not I_3:
            cx = (W - 1) / 2
            cy = (H - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz = np.asarray(wavelets['sym6'], dtype=np.float32)
            Hz_pad = len(Hz) // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
            conv_pad = (len(Hz) + 1) // 2
            images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad])
            images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0])
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

            # Downsample and crop.
            conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
            conv_pad = (len(Hz) - 1) // 2
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :]

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        I_4 = torch.eye(4, device=device)
        M = I_4
        luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)

        if self.brightness > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w))
            b = w * self.brightness_std
            M = translate3d(b, b, b) @ M
            labels += [w]

        if self.contrast > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w))
            c = w.mul(self.contrast_std).exp2()
            M = scale3d(c, c, c) @ M
            labels += [w]

        if self.lumaflip > 0:
            w = torch.randint(2, [N, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w))
            M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M
            labels += [w]

        if self.hue > 0:
            w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max)
            w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w))
            M = rotate3d(luma_axis, w) @ M
            labels += [w.cos() - 1, w.sin()]

        if self.saturation > 0:
            w = torch.randn([N, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w))
            M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M
            labels += [w]

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        if M is not I_4:
            images = images.reshape([N, C, H * W])
            if C == 3:
                images = M[:, :3, :3] @ images + M[:, :3, 3:]
            elif C == 1:
                M = M[:, :3, :].mean(dim=1, keepdims=True)
                images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([N, C, H, W])

        labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1)
        return images, labels

#----------------------------------------------------------------------------
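Usage sketch (not part of the committed file): invoking the pipeline with the same multipliers train.py passes in (`p=0.12`, geometric augmentations enabled); the input batch is random noise and only illustrates the shapes, and the import path assumes the edm/ directory is on PYTHONPATH.

import torch
from training.augment import AugmentPipe

pipe = AugmentPipe(p=0.12, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
images = torch.randn([4, 3, 32, 32])        # NCHW batch, roughly in [-1, 1]
augmented, labels = pipe(images)
print(augmented.shape)                      # torch.Size([4, 3, 32, 32])
print(labels.shape)                         # torch.Size([4, 9]); matches augment_dim=9 in train.py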
edm/training/dataset.py
ADDED
@@ -0,0 +1,250 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Streaming images and labels from datasets created with dataset_tool.py."""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import zipfile
|
| 13 |
+
import PIL.Image
|
| 14 |
+
import json
|
| 15 |
+
import torch
|
| 16 |
+
import dnnlib
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import pyspng
|
| 20 |
+
except ImportError:
|
| 21 |
+
pyspng = None
|
| 22 |
+
|
| 23 |
+
#----------------------------------------------------------------------------
|
| 24 |
+
# Abstract base class for datasets.
|
| 25 |
+
|
| 26 |
+
class Dataset(torch.utils.data.Dataset):
|
| 27 |
+
def __init__(self,
|
| 28 |
+
name, # Name of the dataset.
|
| 29 |
+
raw_shape, # Shape of the raw image data (NCHW).
|
| 30 |
+
max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
|
| 31 |
+
use_labels = False, # Enable conditioning labels? False = label dimension is zero.
|
| 32 |
+
xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
|
| 33 |
+
random_seed = 0, # Random seed to use when applying max_size.
|
| 34 |
+
cache = False, # Cache images in CPU memory?
|
| 35 |
+
):
|
| 36 |
+
self._name = name
|
| 37 |
+
self._raw_shape = list(raw_shape)
|
| 38 |
+
self._use_labels = use_labels
|
| 39 |
+
self._cache = cache
|
| 40 |
+
self._cached_images = dict() # {raw_idx: np.ndarray, ...}
|
| 41 |
+
self._raw_labels = None
|
| 42 |
+
self._label_shape = None
|
| 43 |
+
|
| 44 |
+
# Apply max_size.
|
| 45 |
+
self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
|
| 46 |
+
if (max_size is not None) and (self._raw_idx.size > max_size):
|
| 47 |
+
np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
|
| 48 |
+
self._raw_idx = np.sort(self._raw_idx[:max_size])
|
| 49 |
+
|
| 50 |
+
# Apply xflip.
|
| 51 |
+
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
|
| 52 |
+
if xflip:
|
| 53 |
+
self._raw_idx = np.tile(self._raw_idx, 2)
|
| 54 |
+
self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
|
| 55 |
+
|
| 56 |
+
def _get_raw_labels(self):
|
| 57 |
+
if self._raw_labels is None:
|
| 58 |
+
self._raw_labels = self._load_raw_labels() if self._use_labels else None
|
| 59 |
+
if self._raw_labels is None:
|
| 60 |
+
self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
|
| 61 |
+
assert isinstance(self._raw_labels, np.ndarray)
|
| 62 |
+
assert self._raw_labels.shape[0] == self._raw_shape[0]
|
| 63 |
+
assert self._raw_labels.dtype in [np.float32, np.int64]
|
| 64 |
+
if self._raw_labels.dtype == np.int64:
|
| 65 |
+
assert self._raw_labels.ndim == 1
|
| 66 |
+
assert np.all(self._raw_labels >= 0)
|
| 67 |
+
return self._raw_labels
|
| 68 |
+
|
| 69 |
+
def close(self): # to be overridden by subclass
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
def _load_raw_image(self, raw_idx): # to be overridden by subclass
|
| 73 |
+
raise NotImplementedError
|
| 74 |
+
|
| 75 |
+
def _load_raw_labels(self): # to be overridden by subclass
|
| 76 |
+
raise NotImplementedError
|
| 77 |
+
|
| 78 |
+
def __getstate__(self):
|
| 79 |
+
return dict(self.__dict__, _raw_labels=None)
|
| 80 |
+
|
| 81 |
+
def __del__(self):
|
| 82 |
+
try:
|
| 83 |
+
self.close()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
def __len__(self):
|
| 88 |
+
return self._raw_idx.size
|
| 89 |
+
|
| 90 |
+
def __getitem__(self, idx):
|
| 91 |
+
raw_idx = self._raw_idx[idx]
|
| 92 |
+
image = self._cached_images.get(raw_idx, None)
|
| 93 |
+
if image is None:
|
| 94 |
+
image = self._load_raw_image(raw_idx)
|
| 95 |
+
if self._cache:
|
| 96 |
+
self._cached_images[raw_idx] = image
|
| 97 |
+
assert isinstance(image, np.ndarray)
|
| 98 |
+
assert list(image.shape) == self.image_shape
|
| 99 |
+
assert image.dtype == np.uint8
|
| 100 |
+
if self._xflip[idx]:
|
| 101 |
+
assert image.ndim == 3 # CHW
|
| 102 |
+
image = image[:, :, ::-1]
|
| 103 |
+
return image.copy(), self.get_label(idx)
|
| 104 |
+
|
| 105 |
+
def get_label(self, idx):
|
| 106 |
+
label = self._get_raw_labels()[self._raw_idx[idx]]
|
| 107 |
+
if label.dtype == np.int64:
|
| 108 |
+
onehot = np.zeros(self.label_shape, dtype=np.float32)
|
| 109 |
+
onehot[label] = 1
|
| 110 |
+
label = onehot
|
| 111 |
+
return label.copy()
|
| 112 |
+
|
| 113 |
+
def get_details(self, idx):
|
| 114 |
+
d = dnnlib.EasyDict()
|
| 115 |
+
d.raw_idx = int(self._raw_idx[idx])
|
| 116 |
+
d.xflip = (int(self._xflip[idx]) != 0)
|
| 117 |
+
d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
|
| 118 |
+
return d
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def name(self):
|
| 122 |
+
return self._name
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def image_shape(self):
|
| 126 |
+
return list(self._raw_shape[1:])
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def num_channels(self):
|
| 130 |
+
assert len(self.image_shape) == 3 # CHW
|
| 131 |
+
return self.image_shape[0]
|
| 132 |
+
|
| 133 |
+
@property
|
| 134 |
+
def resolution(self):
|
| 135 |
+
assert len(self.image_shape) == 3 # CHW
|
| 136 |
+
assert self.image_shape[1] == self.image_shape[2]
|
| 137 |
+
return self.image_shape[1]
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def label_shape(self):
|
| 141 |
+
if self._label_shape is None:
|
| 142 |
+
raw_labels = self._get_raw_labels()
|
| 143 |
+
if raw_labels.dtype == np.int64:
|
| 144 |
+
self._label_shape = [int(np.max(raw_labels)) + 1]
|
| 145 |
+
else:
|
| 146 |
+
self._label_shape = raw_labels.shape[1:]
|
| 147 |
+
return list(self._label_shape)
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
def label_dim(self):
|
| 151 |
+
assert len(self.label_shape) == 1
|
| 152 |
+
return self.label_shape[0]
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def has_labels(self):
|
| 156 |
+
return any(x != 0 for x in self.label_shape)
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def has_onehot_labels(self):
|
| 160 |
+
return self._get_raw_labels().dtype == np.int64
|
| 161 |
+
|
| 162 |
+
#----------------------------------------------------------------------------
|
| 163 |
+
# Dataset subclass that loads images recursively from the specified directory
|
| 164 |
+
# or ZIP file.
|
| 165 |
+
|
| 166 |
+
class ImageFolderDataset(Dataset):
|
| 167 |
+
def __init__(self,
|
| 168 |
+
path, # Path to directory or zip.
|
| 169 |
+
resolution = None, # Ensure specific resolution, None = highest available.
|
| 170 |
+
use_pyspng = True, # Use pyspng if available?
|
| 171 |
+
**super_kwargs, # Additional arguments for the Dataset base class.
|
| 172 |
+
):
|
| 173 |
+
self._path = path
|
| 174 |
+
self._use_pyspng = use_pyspng
|
| 175 |
+
self._zipfile = None
|
| 176 |
+
|
| 177 |
+
if os.path.isdir(self._path):
|
| 178 |
+
self._type = 'dir'
|
| 179 |
+
self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
|
| 180 |
+
elif self._file_ext(self._path) == '.zip':
|
| 181 |
+
self._type = 'zip'
|
| 182 |
+
self._all_fnames = set(self._get_zipfile().namelist())
|
| 183 |
+
else:
|
| 184 |
+
raise IOError('Path must point to a directory or zip')
|
| 185 |
+
|
| 186 |
+
PIL.Image.init()
|
| 187 |
+
self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
|
| 188 |
+
if len(self._image_fnames) == 0:
|
| 189 |
+
raise IOError('No image files found in the specified path')
|
| 190 |
+
|
| 191 |
+
name = os.path.splitext(os.path.basename(self._path))[0]
|
| 192 |
+
raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
|
| 193 |
+
if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
|
| 194 |
+
raise IOError('Image files do not match the specified resolution')
|
| 195 |
+
super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def _file_ext(fname):
|
| 199 |
+
return os.path.splitext(fname)[1].lower()
|
| 200 |
+
|
| 201 |
+
def _get_zipfile(self):
|
| 202 |
+
assert self._type == 'zip'
|
| 203 |
+
if self._zipfile is None:
|
| 204 |
+
self._zipfile = zipfile.ZipFile(self._path)
|
| 205 |
+
return self._zipfile
|
| 206 |
+
|
| 207 |
+
def _open_file(self, fname):
|
| 208 |
+
if self._type == 'dir':
|
| 209 |
+
return open(os.path.join(self._path, fname), 'rb')
|
| 210 |
+
if self._type == 'zip':
|
| 211 |
+
return self._get_zipfile().open(fname, 'r')
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
def close(self):
|
| 215 |
+
try:
|
| 216 |
+
if self._zipfile is not None:
|
| 217 |
+
self._zipfile.close()
|
| 218 |
+
finally:
|
| 219 |
+
self._zipfile = None
|
| 220 |
+
|
| 221 |
+
def __getstate__(self):
|
| 222 |
+
return dict(super().__getstate__(), _zipfile=None)
|
| 223 |
+
|
| 224 |
+
def _load_raw_image(self, raw_idx):
|
| 225 |
+
fname = self._image_fnames[raw_idx]
|
| 226 |
+
with self._open_file(fname) as f:
|
| 227 |
+
if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
|
| 228 |
+
image = pyspng.load(f.read())
|
| 229 |
+
else:
|
| 230 |
+
image = np.array(PIL.Image.open(f))
|
| 231 |
+
if image.ndim == 2:
|
| 232 |
+
image = image[:, :, np.newaxis] # HW => HWC
|
| 233 |
+
image = image.transpose(2, 0, 1) # HWC => CHW
|
| 234 |
+
return image
|
| 235 |
+
|
| 236 |
+
def _load_raw_labels(self):
|
| 237 |
+
fname = 'dataset.json'
|
| 238 |
+
if fname not in self._all_fnames:
|
| 239 |
+
return None
|
| 240 |
+
with self._open_file(fname) as f:
|
| 241 |
+
labels = json.load(f)['labels']
|
| 242 |
+
if labels is None:
|
| 243 |
+
return None
|
| 244 |
+
labels = dict(labels)
|
| 245 |
+
labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
|
| 246 |
+
labels = np.array(labels)
|
| 247 |
+
labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
|
| 248 |
+
return labels
|
| 249 |
+
|
| 250 |
+
#----------------------------------------------------------------------------
|
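The ImageFolderDataset above yields uint8 CHW images and, optionally, per-image labels read from dataset.json. A minimal usage sketch follows; the import path, archive path, and batch size are illustrative assumptions, not part of this commit.

    # Minimal sketch: load one batch from a dataset prepared with edm/dataset_tool.py.
    import torch
    from edm.training.dataset import ImageFolderDataset  # assumed import path for this repo layout

    dataset = ImageFolderDataset(path='datasets/cifar10-32x32.zip', use_labels=True)  # hypothetical archive
    loader = torch.utils.data.DataLoader(dataset, batch_size=8)
    images, labels = next(iter(loader))            # images: uint8 NCHW, labels: one-hot float32
    images = images.to(torch.float32) / 127.5 - 1  # same [-1, 1] scaling the training loop applies
    print(images.shape, labels.shape)              # e.g. torch.Size([8, 3, 32, 32]), torch.Size([8, 10])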
edm/training/loss.py
ADDED
|
@@ -0,0 +1,82 @@
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Loss functions used in the paper
|
| 9 |
+
"Elucidating the Design Space of Diffusion-Based Generative Models"."""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from edm.torch_utils import persistence
|
| 13 |
+
|
| 14 |
+
#----------------------------------------------------------------------------
|
| 15 |
+
# Loss function corresponding to the variance preserving (VP) formulation
|
| 16 |
+
# from the paper "Score-Based Generative Modeling through Stochastic
|
| 17 |
+
# Differential Equations".
|
| 18 |
+
|
| 19 |
+
@persistence.persistent_class
|
| 20 |
+
class VPLoss:
|
| 21 |
+
def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
|
| 22 |
+
self.beta_d = beta_d
|
| 23 |
+
self.beta_min = beta_min
|
| 24 |
+
self.epsilon_t = epsilon_t
|
| 25 |
+
|
| 26 |
+
def __call__(self, net, images, labels, augment_pipe=None):
|
| 27 |
+
rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
|
| 28 |
+
sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
|
| 29 |
+
weight = 1 / sigma ** 2
|
| 30 |
+
y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
|
| 31 |
+
n = torch.randn_like(y) * sigma
|
| 32 |
+
D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
|
| 33 |
+
loss = weight * ((D_yn - y) ** 2)
|
| 34 |
+
return loss
|
| 35 |
+
|
| 36 |
+
def sigma(self, t):
|
| 37 |
+
t = torch.as_tensor(t)
|
| 38 |
+
return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
|
| 39 |
+
|
| 40 |
+
#----------------------------------------------------------------------------
|
| 41 |
+
# Loss function corresponding to the variance exploding (VE) formulation
|
| 42 |
+
# from the paper "Score-Based Generative Modeling through Stochastic
|
| 43 |
+
# Differential Equations".
|
| 44 |
+
|
| 45 |
+
@persistence.persistent_class
|
| 46 |
+
class VELoss:
|
| 47 |
+
def __init__(self, sigma_min=0.02, sigma_max=100):
|
| 48 |
+
self.sigma_min = sigma_min
|
| 49 |
+
self.sigma_max = sigma_max
|
| 50 |
+
|
| 51 |
+
def __call__(self, net, images, labels, augment_pipe=None):
|
| 52 |
+
rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
|
| 53 |
+
sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
|
| 54 |
+
weight = 1 / sigma ** 2
|
| 55 |
+
y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
|
| 56 |
+
n = torch.randn_like(y) * sigma
|
| 57 |
+
D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
|
| 58 |
+
loss = weight * ((D_yn - y) ** 2)
|
| 59 |
+
return loss
|
| 60 |
+
|
| 61 |
+
#----------------------------------------------------------------------------
|
| 62 |
+
# Improved loss function proposed in the paper "Elucidating the Design Space
|
| 63 |
+
# of Diffusion-Based Generative Models" (EDM).
|
| 64 |
+
|
| 65 |
+
@persistence.persistent_class
|
| 66 |
+
class EDMLoss:
|
| 67 |
+
def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
|
| 68 |
+
self.P_mean = P_mean
|
| 69 |
+
self.P_std = P_std
|
| 70 |
+
self.sigma_data = sigma_data
|
| 71 |
+
|
| 72 |
+
def __call__(self, net, images, labels=None, augment_pipe=None):
|
| 73 |
+
rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
|
| 74 |
+
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
|
| 75 |
+
weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
|
| 76 |
+
y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
|
| 77 |
+
n = torch.randn_like(y) * sigma
|
| 78 |
+
D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
|
| 79 |
+
loss = weight * ((D_yn - y) ** 2)
|
| 80 |
+
return loss
|
| 81 |
+
|
| 82 |
+
#----------------------------------------------------------------------------
|
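All three loss classes above share the same calling convention: the training loop invokes loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) and gets back a per-pixel weighted squared error that it then sums and scales. A minimal sketch of that contract follows; the stand-in denoiser is a deliberately trivial placeholder (a real run passes an EDMPrecond-wrapped U-Net from edm/training/networks.py below), and every name in it is illustrative.

    # Minimal sketch of the EDMLoss calling convention, using a trivial stand-in denoiser.
    import torch
    from edm.training.loss import EDMLoss  # assumed import path for this repo layout

    class StandInDenoiser(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.scale = torch.nn.Parameter(torch.ones([]))
        def forward(self, x, sigma, labels=None, augment_labels=None):
            return self.scale * x  # a real network returns the denoised estimate D(x; sigma)

    net = StandInDenoiser()
    loss_fn = EDMLoss(P_mean=-1.2, P_std=1.2, sigma_data=0.5)
    images = torch.randn(4, 3, 32, 32)       # stand-in for real data scaled to [-1, 1]
    loss = loss_fn(net=net, images=images)   # weighted squared error, shape [4, 3, 32, 32]
    loss.mean().backward()                   # the actual loop uses loss.sum() * loss_scaling / batch_gpu_total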
edm/training/networks.py
ADDED
|
@@ -0,0 +1,673 @@
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Model architectures and preconditioning schemes used in the paper
|
| 9 |
+
"Elucidating the Design Space of Diffusion-Based Generative Models"."""
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch_utils import persistence
|
| 14 |
+
from torch.nn.functional import silu
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
# Unified routine for initializing weights and biases.
|
| 18 |
+
|
| 19 |
+
def weight_init(shape, mode, fan_in, fan_out):
|
| 20 |
+
if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
|
| 21 |
+
if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
|
| 22 |
+
if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
|
| 23 |
+
if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape)
|
| 24 |
+
raise ValueError(f'Invalid init mode "{mode}"')
|
| 25 |
+
|
| 26 |
+
#----------------------------------------------------------------------------
|
| 27 |
+
# Fully-connected layer.
|
| 28 |
+
|
| 29 |
+
@persistence.persistent_class
|
| 30 |
+
class Linear(torch.nn.Module):
|
| 31 |
+
def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.in_features = in_features
|
| 34 |
+
self.out_features = out_features
|
| 35 |
+
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
|
| 36 |
+
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
|
| 37 |
+
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
x = x @ self.weight.to(x.dtype).t()
|
| 41 |
+
if self.bias is not None:
|
| 42 |
+
x = x.add_(self.bias.to(x.dtype))
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
#----------------------------------------------------------------------------
|
| 46 |
+
# Convolutional layer with optional up/downsampling.
|
| 47 |
+
|
| 48 |
+
@persistence.persistent_class
|
| 49 |
+
class Conv2d(torch.nn.Module):
|
| 50 |
+
def __init__(self,
|
| 51 |
+
in_channels, out_channels, kernel, bias=True, up=False, down=False,
|
| 52 |
+
resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0,
|
| 53 |
+
):
|
| 54 |
+
assert not (up and down)
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.in_channels = in_channels
|
| 57 |
+
self.out_channels = out_channels
|
| 58 |
+
self.up = up
|
| 59 |
+
self.down = down
|
| 60 |
+
self.fused_resample = fused_resample
|
| 61 |
+
init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel)
|
| 62 |
+
self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None
|
| 63 |
+
self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None
|
| 64 |
+
f = torch.as_tensor(resample_filter, dtype=torch.float32)
|
| 65 |
+
f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
|
| 66 |
+
self.register_buffer('resample_filter', f if up or down else None)
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
w = self.weight.to(x.dtype) if self.weight is not None else None
|
| 70 |
+
b = self.bias.to(x.dtype) if self.bias is not None else None
|
| 71 |
+
f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None
|
| 72 |
+
w_pad = w.shape[-1] // 2 if w is not None else 0
|
| 73 |
+
f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0
|
| 74 |
+
|
| 75 |
+
if self.fused_resample and self.up and w is not None:
|
| 76 |
+
x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0))
|
| 77 |
+
x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
|
| 78 |
+
elif self.fused_resample and self.down and w is not None:
|
| 79 |
+
x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad)
|
| 80 |
+
x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2)
|
| 81 |
+
else:
|
| 82 |
+
if self.up:
|
| 83 |
+
x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
|
| 84 |
+
if self.down:
|
| 85 |
+
x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
|
| 86 |
+
if w is not None:
|
| 87 |
+
x = torch.nn.functional.conv2d(x, w, padding=w_pad)
|
| 88 |
+
if b is not None:
|
| 89 |
+
x = x.add_(b.reshape(1, -1, 1, 1))
|
| 90 |
+
return x
|
| 91 |
+
|
| 92 |
+
#----------------------------------------------------------------------------
|
| 93 |
+
# Group normalization.
|
| 94 |
+
|
| 95 |
+
@persistence.persistent_class
|
| 96 |
+
class GroupNorm(torch.nn.Module):
|
| 97 |
+
def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.num_groups = min(num_groups, num_channels // min_channels_per_group)
|
| 100 |
+
self.eps = eps
|
| 101 |
+
self.weight = torch.nn.Parameter(torch.ones(num_channels))
|
| 102 |
+
self.bias = torch.nn.Parameter(torch.zeros(num_channels))
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps)
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
#----------------------------------------------------------------------------
|
| 109 |
+
# Attention weight computation, i.e., softmax(Q^T * K).
|
| 110 |
+
# Performs all computation using FP32, but uses the original datatype for
|
| 111 |
+
# inputs/outputs/gradients to conserve memory.
|
| 112 |
+
|
| 113 |
+
class AttentionOp(torch.autograd.Function):
|
| 114 |
+
@staticmethod
|
| 115 |
+
def forward(ctx, q, k):
|
| 116 |
+
w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype)
|
| 117 |
+
ctx.save_for_backward(q, k, w)
|
| 118 |
+
return w
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def backward(ctx, dw):
|
| 122 |
+
q, k, w = ctx.saved_tensors
|
| 123 |
+
db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32)
|
| 124 |
+
dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1])
|
| 125 |
+
dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1])
|
| 126 |
+
return dq, dk
|
| 127 |
+
|
| 128 |
+
#----------------------------------------------------------------------------
|
| 129 |
+
# Unified U-Net block with optional up/downsampling and self-attention.
|
| 130 |
+
# Represents the union of all features employed by the DDPM++, NCSN++, and
|
| 131 |
+
# ADM architectures.
|
| 132 |
+
|
| 133 |
+
@persistence.persistent_class
|
| 134 |
+
class UNetBlock(torch.nn.Module):
|
| 135 |
+
def __init__(self,
|
| 136 |
+
in_channels, out_channels, emb_channels, up=False, down=False, attention=False,
|
| 137 |
+
num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5,
|
| 138 |
+
resample_filter=[1,1], resample_proj=False, adaptive_scale=True,
|
| 139 |
+
init=dict(), init_zero=dict(init_weight=0), init_attn=None,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.in_channels = in_channels
|
| 143 |
+
self.out_channels = out_channels
|
| 144 |
+
self.emb_channels = emb_channels
|
| 145 |
+
self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head
|
| 146 |
+
self.dropout = dropout
|
| 147 |
+
self.skip_scale = skip_scale
|
| 148 |
+
self.adaptive_scale = adaptive_scale
|
| 149 |
+
|
| 150 |
+
self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
|
| 151 |
+
self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init)
|
| 152 |
+
self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init)
|
| 153 |
+
self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
|
| 154 |
+
self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero)
|
| 155 |
+
|
| 156 |
+
self.skip = None
|
| 157 |
+
if out_channels != in_channels or up or down:
|
| 158 |
+
kernel = 1 if resample_proj or out_channels != in_channels else 0
|
| 159 |
+
self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init)
|
| 160 |
+
|
| 161 |
+
if self.num_heads:
|
| 162 |
+
self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
|
| 163 |
+
self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init))
|
| 164 |
+
self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero)
|
| 165 |
+
|
| 166 |
+
def forward(self, x, emb):
|
| 167 |
+
orig = x
|
| 168 |
+
x = self.conv0(silu(self.norm0(x)))
|
| 169 |
+
|
| 170 |
+
params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
|
| 171 |
+
if self.adaptive_scale:
|
| 172 |
+
scale, shift = params.chunk(chunks=2, dim=1)
|
| 173 |
+
x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
|
| 174 |
+
else:
|
| 175 |
+
x = silu(self.norm1(x.add_(params)))
|
| 176 |
+
|
| 177 |
+
x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
|
| 178 |
+
x = x.add_(self.skip(orig) if self.skip is not None else orig)
|
| 179 |
+
x = x * self.skip_scale
|
| 180 |
+
|
| 181 |
+
if self.num_heads:
|
| 182 |
+
q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
|
| 183 |
+
w = AttentionOp.apply(q, k)
|
| 184 |
+
a = torch.einsum('nqk,nck->ncq', w, v)
|
| 185 |
+
x = self.proj(a.reshape(*x.shape)).add_(x)
|
| 186 |
+
x = x * self.skip_scale
|
| 187 |
+
return x
|
| 188 |
+
|
| 189 |
+
#----------------------------------------------------------------------------
|
| 190 |
+
# Timestep embedding used in the DDPM++ and ADM architectures.
|
| 191 |
+
|
| 192 |
+
@persistence.persistent_class
|
| 193 |
+
class PositionalEmbedding(torch.nn.Module):
|
| 194 |
+
def __init__(self, num_channels, max_positions=10000, endpoint=False):
|
| 195 |
+
super().__init__()
|
| 196 |
+
self.num_channels = num_channels
|
| 197 |
+
self.max_positions = max_positions
|
| 198 |
+
self.endpoint = endpoint
|
| 199 |
+
|
| 200 |
+
def forward(self, x):
|
| 201 |
+
freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
|
| 202 |
+
freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
|
| 203 |
+
freqs = (1 / self.max_positions) ** freqs
|
| 204 |
+
x = x.ger(freqs.to(x.dtype))
|
| 205 |
+
x = torch.cat([x.cos(), x.sin()], dim=1)
|
| 206 |
+
return x
|
| 207 |
+
|
| 208 |
+
#----------------------------------------------------------------------------
|
| 209 |
+
# Timestep embedding used in the NCSN++ architecture.
|
| 210 |
+
|
| 211 |
+
@persistence.persistent_class
|
| 212 |
+
class FourierEmbedding(torch.nn.Module):
|
| 213 |
+
def __init__(self, num_channels, scale=16):
|
| 214 |
+
super().__init__()
|
| 215 |
+
self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)
|
| 216 |
+
|
| 217 |
+
def forward(self, x):
|
| 218 |
+
x = x.ger((2 * np.pi * self.freqs).to(x.dtype))
|
| 219 |
+
x = torch.cat([x.cos(), x.sin()], dim=1)
|
| 220 |
+
return x
|
| 221 |
+
|
| 222 |
+
#----------------------------------------------------------------------------
|
| 223 |
+
# Reimplementation of the DDPM++ and NCSN++ architectures from the paper
|
| 224 |
+
# "Score-Based Generative Modeling through Stochastic Differential
|
| 225 |
+
# Equations". Equivalent to the original implementation by Song et al.,
|
| 226 |
+
# available at https://github.com/yang-song/score_sde_pytorch
|
| 227 |
+
|
| 228 |
+
@persistence.persistent_class
|
| 229 |
+
class SongUNet(torch.nn.Module):
|
| 230 |
+
def __init__(self,
|
| 231 |
+
img_resolution, # Image resolution at input/output.
|
| 232 |
+
in_channels, # Number of color channels at input.
|
| 233 |
+
out_channels, # Number of color channels at output.
|
| 234 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 235 |
+
augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
|
| 236 |
+
|
| 237 |
+
model_channels = 128, # Base multiplier for the number of channels.
|
| 238 |
+
channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels.
|
| 239 |
+
channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
|
| 240 |
+
num_blocks = 4, # Number of residual blocks per resolution.
|
| 241 |
+
attn_resolutions = [16], # List of resolutions with self-attention.
|
| 242 |
+
dropout = 0.10, # Dropout probability of intermediate activations.
|
| 243 |
+
label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
|
| 244 |
+
|
| 245 |
+
embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
|
| 246 |
+
channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
|
| 247 |
+
encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
|
| 248 |
+
decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
|
| 249 |
+
resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
|
| 250 |
+
):
|
| 251 |
+
assert embedding_type in ['fourier', 'positional']
|
| 252 |
+
assert encoder_type in ['standard', 'skip', 'residual']
|
| 253 |
+
assert decoder_type in ['standard', 'skip']
|
| 254 |
+
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.label_dropout = label_dropout
|
| 257 |
+
emb_channels = model_channels * channel_mult_emb
|
| 258 |
+
noise_channels = model_channels * channel_mult_noise
|
| 259 |
+
init = dict(init_mode='xavier_uniform')
|
| 260 |
+
init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
|
| 261 |
+
init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
|
| 262 |
+
block_kwargs = dict(
|
| 263 |
+
emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6,
|
| 264 |
+
resample_filter=resample_filter, resample_proj=True, adaptive_scale=False,
|
| 265 |
+
init=init, init_zero=init_zero, init_attn=init_attn,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Mapping.
|
| 269 |
+
self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels)
|
| 270 |
+
self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None
|
| 271 |
+
self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None
|
| 272 |
+
self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
|
| 273 |
+
self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
|
| 274 |
+
|
| 275 |
+
# Encoder.
|
| 276 |
+
self.enc = torch.nn.ModuleDict()
|
| 277 |
+
cout = in_channels
|
| 278 |
+
caux = in_channels
|
| 279 |
+
for level, mult in enumerate(channel_mult):
|
| 280 |
+
res = img_resolution >> level
|
| 281 |
+
if level == 0:
|
| 282 |
+
cin = cout
|
| 283 |
+
cout = model_channels
|
| 284 |
+
self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
|
| 285 |
+
else:
|
| 286 |
+
self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
|
| 287 |
+
if encoder_type == 'skip':
|
| 288 |
+
self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter)
|
| 289 |
+
self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init)
|
| 290 |
+
if encoder_type == 'residual':
|
| 291 |
+
self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init)
|
| 292 |
+
caux = cout
|
| 293 |
+
for idx in range(num_blocks):
|
| 294 |
+
cin = cout
|
| 295 |
+
cout = model_channels * mult
|
| 296 |
+
attn = (res in attn_resolutions)
|
| 297 |
+
self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
|
| 298 |
+
skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name]
|
| 299 |
+
|
| 300 |
+
# Decoder.
|
| 301 |
+
self.dec = torch.nn.ModuleDict()
|
| 302 |
+
for level, mult in reversed(list(enumerate(channel_mult))):
|
| 303 |
+
res = img_resolution >> level
|
| 304 |
+
if level == len(channel_mult) - 1:
|
| 305 |
+
self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
|
| 306 |
+
self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
|
| 307 |
+
else:
|
| 308 |
+
self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
|
| 309 |
+
for idx in range(num_blocks + 1):
|
| 310 |
+
cin = cout + skips.pop()
|
| 311 |
+
cout = model_channels * mult
|
| 312 |
+
attn = (idx == num_blocks and res in attn_resolutions)
|
| 313 |
+
self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
|
| 314 |
+
if decoder_type == 'skip' or level == 0:
|
| 315 |
+
if decoder_type == 'skip' and level < len(channel_mult) - 1:
|
| 316 |
+
self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter)
|
| 317 |
+
self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6)
|
| 318 |
+
self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
|
| 319 |
+
|
| 320 |
+
def forward(self, x, noise_labels, class_labels, augment_labels=None):
|
| 321 |
+
# Mapping.
|
| 322 |
+
emb = self.map_noise(noise_labels)
|
| 323 |
+
emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
|
| 324 |
+
if self.map_label is not None:
|
| 325 |
+
tmp = class_labels
|
| 326 |
+
if self.training and self.label_dropout:
|
| 327 |
+
tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
|
| 328 |
+
emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features))
|
| 329 |
+
if self.map_augment is not None and augment_labels is not None:
|
| 330 |
+
emb = emb + self.map_augment(augment_labels)
|
| 331 |
+
emb = silu(self.map_layer0(emb))
|
| 332 |
+
emb = silu(self.map_layer1(emb))
|
| 333 |
+
|
| 334 |
+
# Encoder.
|
| 335 |
+
skips = []
|
| 336 |
+
aux = x
|
| 337 |
+
for name, block in self.enc.items():
|
| 338 |
+
if 'aux_down' in name:
|
| 339 |
+
aux = block(aux)
|
| 340 |
+
elif 'aux_skip' in name:
|
| 341 |
+
x = skips[-1] = x + block(aux)
|
| 342 |
+
elif 'aux_residual' in name:
|
| 343 |
+
x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
|
| 344 |
+
else:
|
| 345 |
+
x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
|
| 346 |
+
skips.append(x)
|
| 347 |
+
|
| 348 |
+
# Decoder.
|
| 349 |
+
aux = None
|
| 350 |
+
tmp = None
|
| 351 |
+
for name, block in self.dec.items():
|
| 352 |
+
if 'aux_up' in name:
|
| 353 |
+
aux = block(aux)
|
| 354 |
+
elif 'aux_norm' in name:
|
| 355 |
+
tmp = block(x)
|
| 356 |
+
elif 'aux_conv' in name:
|
| 357 |
+
tmp = block(silu(tmp))
|
| 358 |
+
aux = tmp if aux is None else tmp + aux
|
| 359 |
+
else:
|
| 360 |
+
if x.shape[1] != block.in_channels:
|
| 361 |
+
x = torch.cat([x, skips.pop()], dim=1)
|
| 362 |
+
x = block(x, emb)
|
| 363 |
+
return aux
|
| 364 |
+
|
| 365 |
+
#----------------------------------------------------------------------------
|
| 366 |
+
# Reimplementation of the ADM architecture from the paper
|
| 367 |
+
# "Diffusion Models Beat GANS on Image Synthesis". Equivalent to the
|
| 368 |
+
# original implementation by Dhariwal and Nichol, available at
|
| 369 |
+
# https://github.com/openai/guided-diffusion
|
| 370 |
+
|
| 371 |
+
@persistence.persistent_class
|
| 372 |
+
class DhariwalUNet(torch.nn.Module):
|
| 373 |
+
def __init__(self,
|
| 374 |
+
img_resolution, # Image resolution at input/output.
|
| 375 |
+
in_channels, # Number of color channels at input.
|
| 376 |
+
out_channels, # Number of color channels at output.
|
| 377 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 378 |
+
augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
|
| 379 |
+
|
| 380 |
+
model_channels = 192, # Base multiplier for the number of channels.
|
| 381 |
+
channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels.
|
| 382 |
+
channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
|
| 383 |
+
num_blocks = 3, # Number of residual blocks per resolution.
|
| 384 |
+
attn_resolutions = [32,16,8], # List of resolutions with self-attention.
|
| 385 |
+
dropout = 0.10, # Dropout probability of intermediate activations.
|
| 386 |
+
label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
|
| 387 |
+
):
|
| 388 |
+
super().__init__()
|
| 389 |
+
self.label_dropout = label_dropout
|
| 390 |
+
emb_channels = model_channels * channel_mult_emb
|
| 391 |
+
init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3))
|
| 392 |
+
init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0)
|
| 393 |
+
block_kwargs = dict(emb_channels=emb_channels, channels_per_head=64, dropout=dropout, init=init, init_zero=init_zero)
|
| 394 |
+
|
| 395 |
+
# Mapping.
|
| 396 |
+
self.map_noise = PositionalEmbedding(num_channels=model_channels)
|
| 397 |
+
self.map_augment = Linear(in_features=augment_dim, out_features=model_channels, bias=False, **init_zero) if augment_dim else None
|
| 398 |
+
self.map_layer0 = Linear(in_features=model_channels, out_features=emb_channels, **init)
|
| 399 |
+
self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
|
| 400 |
+
self.map_label = Linear(in_features=label_dim, out_features=emb_channels, bias=False, init_mode='kaiming_normal', init_weight=np.sqrt(label_dim)) if label_dim else None
|
| 401 |
+
|
| 402 |
+
# Encoder.
|
| 403 |
+
self.enc = torch.nn.ModuleDict()
|
| 404 |
+
cout = in_channels
|
| 405 |
+
for level, mult in enumerate(channel_mult):
|
| 406 |
+
res = img_resolution >> level
|
| 407 |
+
if level == 0:
|
| 408 |
+
cin = cout
|
| 409 |
+
cout = model_channels * mult
|
| 410 |
+
self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
|
| 411 |
+
else:
|
| 412 |
+
self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
|
| 413 |
+
for idx in range(num_blocks):
|
| 414 |
+
cin = cout
|
| 415 |
+
cout = model_channels * mult
|
| 416 |
+
self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
|
| 417 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 418 |
+
|
| 419 |
+
# Decoder.
|
| 420 |
+
self.dec = torch.nn.ModuleDict()
|
| 421 |
+
for level, mult in reversed(list(enumerate(channel_mult))):
|
| 422 |
+
res = img_resolution >> level
|
| 423 |
+
if level == len(channel_mult) - 1:
|
| 424 |
+
self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
|
| 425 |
+
self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
|
| 426 |
+
else:
|
| 427 |
+
self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
|
| 428 |
+
for idx in range(num_blocks + 1):
|
| 429 |
+
cin = cout + skips.pop()
|
| 430 |
+
cout = model_channels * mult
|
| 431 |
+
self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
|
| 432 |
+
self.out_norm = GroupNorm(num_channels=cout)
|
| 433 |
+
self.out_conv = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
|
| 434 |
+
|
| 435 |
+
def forward(self, x, noise_labels, class_labels, augment_labels=None):
|
| 436 |
+
# Mapping.
|
| 437 |
+
emb = self.map_noise(noise_labels)
|
| 438 |
+
if self.map_augment is not None and augment_labels is not None:
|
| 439 |
+
emb = emb + self.map_augment(augment_labels)
|
| 440 |
+
emb = silu(self.map_layer0(emb))
|
| 441 |
+
emb = self.map_layer1(emb)
|
| 442 |
+
if self.map_label is not None:
|
| 443 |
+
tmp = class_labels
|
| 444 |
+
if self.training and self.label_dropout:
|
| 445 |
+
tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
|
| 446 |
+
emb = emb + self.map_label(tmp)
|
| 447 |
+
emb = silu(emb)
|
| 448 |
+
|
| 449 |
+
# Encoder.
|
| 450 |
+
skips = []
|
| 451 |
+
for block in self.enc.values():
|
| 452 |
+
x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
|
| 453 |
+
skips.append(x)
|
| 454 |
+
|
| 455 |
+
# Decoder.
|
| 456 |
+
for block in self.dec.values():
|
| 457 |
+
if x.shape[1] != block.in_channels:
|
| 458 |
+
x = torch.cat([x, skips.pop()], dim=1)
|
| 459 |
+
x = block(x, emb)
|
| 460 |
+
x = self.out_conv(silu(self.out_norm(x)))
|
| 461 |
+
return x
|
| 462 |
+
|
| 463 |
+
#----------------------------------------------------------------------------
|
| 464 |
+
# Preconditioning corresponding to the variance preserving (VP) formulation
|
| 465 |
+
# from the paper "Score-Based Generative Modeling through Stochastic
|
| 466 |
+
# Differential Equations".
|
| 467 |
+
|
| 468 |
+
@persistence.persistent_class
|
| 469 |
+
class VPPrecond(torch.nn.Module):
|
| 470 |
+
def __init__(self,
|
| 471 |
+
img_resolution, # Image resolution.
|
| 472 |
+
img_channels, # Number of color channels.
|
| 473 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 474 |
+
use_fp16 = False, # Execute the underlying model at FP16 precision?
|
| 475 |
+
beta_d = 19.9, # Extent of the noise level schedule.
|
| 476 |
+
beta_min = 0.1, # Initial slope of the noise level schedule.
|
| 477 |
+
M = 1000, # Original number of timesteps in the DDPM formulation.
|
| 478 |
+
epsilon_t = 1e-5, # Minimum t-value used during training.
|
| 479 |
+
model_type = 'SongUNet', # Class name of the underlying model.
|
| 480 |
+
**model_kwargs, # Keyword arguments for the underlying model.
|
| 481 |
+
):
|
| 482 |
+
super().__init__()
|
| 483 |
+
self.img_resolution = img_resolution
|
| 484 |
+
self.img_channels = img_channels
|
| 485 |
+
self.label_dim = label_dim
|
| 486 |
+
self.use_fp16 = use_fp16
|
| 487 |
+
self.beta_d = beta_d
|
| 488 |
+
self.beta_min = beta_min
|
| 489 |
+
self.M = M
|
| 490 |
+
self.epsilon_t = epsilon_t
|
| 491 |
+
self.sigma_min = float(self.sigma(epsilon_t))
|
| 492 |
+
self.sigma_max = float(self.sigma(1))
|
| 493 |
+
self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
|
| 494 |
+
|
| 495 |
+
def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
|
| 496 |
+
x = x.to(torch.float32)
|
| 497 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 498 |
+
class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
|
| 499 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
| 500 |
+
|
| 501 |
+
c_skip = 1
|
| 502 |
+
c_out = -sigma
|
| 503 |
+
c_in = 1 / (sigma ** 2 + 1).sqrt()
|
| 504 |
+
c_noise = (self.M - 1) * self.sigma_inv(sigma)
|
| 505 |
+
|
| 506 |
+
F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
|
| 507 |
+
assert F_x.dtype == dtype
|
| 508 |
+
D_x = c_skip * x + c_out * F_x.to(torch.float32)
|
| 509 |
+
return D_x
|
| 510 |
+
|
| 511 |
+
def sigma(self, t):
|
| 512 |
+
t = torch.as_tensor(t)
|
| 513 |
+
return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
|
| 514 |
+
|
| 515 |
+
def sigma_inv(self, sigma):
|
| 516 |
+
sigma = torch.as_tensor(sigma)
|
| 517 |
+
return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d
|
| 518 |
+
|
| 519 |
+
def round_sigma(self, sigma):
|
| 520 |
+
return torch.as_tensor(sigma)
|
| 521 |
+
|
| 522 |
+
#----------------------------------------------------------------------------
|
| 523 |
+
# Preconditioning corresponding to the variance exploding (VE) formulation
|
| 524 |
+
# from the paper "Score-Based Generative Modeling through Stochastic
|
| 525 |
+
# Differential Equations".
|
| 526 |
+
|
| 527 |
+
@persistence.persistent_class
|
| 528 |
+
class VEPrecond(torch.nn.Module):
|
| 529 |
+
def __init__(self,
|
| 530 |
+
img_resolution, # Image resolution.
|
| 531 |
+
img_channels, # Number of color channels.
|
| 532 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 533 |
+
use_fp16 = False, # Execute the underlying model at FP16 precision?
|
| 534 |
+
sigma_min = 0.02, # Minimum supported noise level.
|
| 535 |
+
sigma_max = 100, # Maximum supported noise level.
|
| 536 |
+
model_type = 'SongUNet', # Class name of the underlying model.
|
| 537 |
+
**model_kwargs, # Keyword arguments for the underlying model.
|
| 538 |
+
):
|
| 539 |
+
super().__init__()
|
| 540 |
+
self.img_resolution = img_resolution
|
| 541 |
+
self.img_channels = img_channels
|
| 542 |
+
self.label_dim = label_dim
|
| 543 |
+
self.use_fp16 = use_fp16
|
| 544 |
+
self.sigma_min = sigma_min
|
| 545 |
+
self.sigma_max = sigma_max
|
| 546 |
+
self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
|
| 547 |
+
|
| 548 |
+
def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
|
| 549 |
+
x = x.to(torch.float32)
|
| 550 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 551 |
+
class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
|
| 552 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
| 553 |
+
|
| 554 |
+
c_skip = 1
|
| 555 |
+
c_out = sigma
|
| 556 |
+
c_in = 1
|
| 557 |
+
c_noise = (0.5 * sigma).log()
|
| 558 |
+
|
| 559 |
+
F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
|
| 560 |
+
assert F_x.dtype == dtype
|
| 561 |
+
D_x = c_skip * x + c_out * F_x.to(torch.float32)
|
| 562 |
+
return D_x
|
| 563 |
+
|
| 564 |
+
def round_sigma(self, sigma):
|
| 565 |
+
return torch.as_tensor(sigma)
|
| 566 |
+
|
| 567 |
+
#----------------------------------------------------------------------------
|
| 568 |
+
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
|
| 569 |
+
# the paper "Improved Denoising Diffusion Probabilistic Models".
|
| 570 |
+
|
| 571 |
+
@persistence.persistent_class
|
| 572 |
+
class iDDPMPrecond(torch.nn.Module):
|
| 573 |
+
def __init__(self,
|
| 574 |
+
img_resolution, # Image resolution.
|
| 575 |
+
img_channels, # Number of color channels.
|
| 576 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 577 |
+
use_fp16 = False, # Execute the underlying model at FP16 precision?
|
| 578 |
+
C_1 = 0.001, # Timestep adjustment at low noise levels.
|
| 579 |
+
C_2 = 0.008, # Timestep adjustment at high noise levels.
|
| 580 |
+
M = 1000, # Original number of timesteps in the DDPM formulation.
|
| 581 |
+
model_type = 'DhariwalUNet', # Class name of the underlying model.
|
| 582 |
+
**model_kwargs, # Keyword arguments for the underlying model.
|
| 583 |
+
):
|
| 584 |
+
super().__init__()
|
| 585 |
+
self.img_resolution = img_resolution
|
| 586 |
+
self.img_channels = img_channels
|
| 587 |
+
self.label_dim = label_dim
|
| 588 |
+
self.use_fp16 = use_fp16
|
| 589 |
+
self.C_1 = C_1
|
| 590 |
+
self.C_2 = C_2
|
| 591 |
+
self.M = M
|
| 592 |
+
self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels*2, label_dim=label_dim, **model_kwargs)
|
| 593 |
+
|
| 594 |
+
u = torch.zeros(M + 1)
|
| 595 |
+
for j in range(M, 0, -1): # M, ..., 1
|
| 596 |
+
u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
|
| 597 |
+
self.register_buffer('u', u)
|
| 598 |
+
self.sigma_min = float(u[M - 1])
|
| 599 |
+
self.sigma_max = float(u[0])
|
| 600 |
+
|
| 601 |
+
def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
|
| 602 |
+
x = x.to(torch.float32)
|
| 603 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 604 |
+
class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
|
| 605 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
| 606 |
+
|
| 607 |
+
c_skip = 1
|
| 608 |
+
c_out = -sigma
|
| 609 |
+
c_in = 1 / (sigma ** 2 + 1).sqrt()
|
| 610 |
+
c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
|
| 611 |
+
|
| 612 |
+
F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
|
| 613 |
+
assert F_x.dtype == dtype
|
| 614 |
+
D_x = c_skip * x + c_out * F_x[:, :self.img_channels].to(torch.float32)
|
| 615 |
+
return D_x
|
| 616 |
+
|
| 617 |
+
def alpha_bar(self, j):
|
| 618 |
+
j = torch.as_tensor(j)
|
| 619 |
+
return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2
|
| 620 |
+
|
| 621 |
+
def round_sigma(self, sigma, return_index=False):
|
| 622 |
+
sigma = torch.as_tensor(sigma)
|
| 623 |
+
index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
|
| 624 |
+
result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
|
| 625 |
+
return result.reshape(sigma.shape).to(sigma.device)
|
| 626 |
+
|
| 627 |
+
#----------------------------------------------------------------------------
|
| 628 |
+
# Improved preconditioning proposed in the paper "Elucidating the Design
|
| 629 |
+
# Space of Diffusion-Based Generative Models" (EDM).
|
| 630 |
+
|
| 631 |
+
@persistence.persistent_class
|
| 632 |
+
class EDMPrecond(torch.nn.Module):
|
| 633 |
+
def __init__(self,
|
| 634 |
+
img_resolution, # Image resolution.
|
| 635 |
+
img_channels, # Number of color channels.
|
| 636 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 637 |
+
use_fp16 = False, # Execute the underlying model at FP16 precision?
|
| 638 |
+
sigma_min = 0, # Minimum supported noise level.
|
| 639 |
+
sigma_max = float('inf'), # Maximum supported noise level.
|
| 640 |
+
sigma_data = 0.5, # Expected standard deviation of the training data.
|
| 641 |
+
model_type = 'DhariwalUNet', # Class name of the underlying model.
|
| 642 |
+
**model_kwargs, # Keyword arguments for the underlying model.
|
| 643 |
+
):
|
| 644 |
+
super().__init__()
|
| 645 |
+
self.img_resolution = img_resolution
|
| 646 |
+
self.img_channels = img_channels
|
| 647 |
+
self.label_dim = label_dim
|
| 648 |
+
self.use_fp16 = use_fp16
|
| 649 |
+
self.sigma_min = sigma_min
|
| 650 |
+
self.sigma_max = sigma_max
|
| 651 |
+
self.sigma_data = sigma_data
|
| 652 |
+
self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
|
| 653 |
+
|
| 654 |
+
def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
|
| 655 |
+
x = x.to(torch.float32)
|
| 656 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 657 |
+
class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
|
| 658 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
| 659 |
+
|
| 660 |
+
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
| 661 |
+
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
|
| 662 |
+
c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
|
| 663 |
+
c_noise = sigma.log() / 4
|
| 664 |
+
|
| 665 |
+
F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
|
| 666 |
+
assert F_x.dtype == dtype
|
| 667 |
+
D_x = c_skip * x + c_out * F_x.to(torch.float32)
|
| 668 |
+
return D_x
|
| 669 |
+
|
| 670 |
+
def round_sigma(self, sigma):
|
| 671 |
+
return torch.as_tensor(sigma)
|
| 672 |
+
|
| 673 |
+
#----------------------------------------------------------------------------
|
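Each *Precond wrapper above exposes the denoiser D(x; sigma) = c_skip*x + c_out*F(c_in*x; c_noise), which is what training and sampling operate on. A minimal sketch of one denoiser evaluation followed by one Euler step of the probability-flow ODE dx/dsigma = (x - D(x; sigma)) / sigma is shown below; the tiny SongUNet configuration is an illustrative assumption, not one of the released checkpoints.

    # Minimal sketch: evaluate the EDM-preconditioned denoiser and take one Euler step.
    import torch
    from edm.training.networks import EDMPrecond  # assumed import path for this repo layout

    net = EDMPrecond(img_resolution=32, img_channels=3, label_dim=0,
                     model_type='SongUNet', model_channels=32).eval()  # tiny illustrative config
    sigma_cur, sigma_next = 80.0, 60.0
    x = torch.randn(2, 3, 32, 32) * sigma_cur      # pure noise at the starting level
    with torch.no_grad():
        D_x = net(x, torch.full([2], sigma_cur))   # D(x; sigma), shape [2, 3, 32, 32]
        d = (x - D_x) / sigma_cur                  # ODE drift
        x = x + (sigma_next - sigma_cur) * d       # one Euler step toward lower noise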
edm/training/training_loop.py
ADDED
|
@@ -0,0 +1,216 @@
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Main training loop."""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
import copy
|
| 13 |
+
import json
|
| 14 |
+
import pickle
|
| 15 |
+
import psutil
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import dnnlib
|
| 19 |
+
from torch_utils import distributed as dist
|
| 20 |
+
from torch_utils import training_stats
|
| 21 |
+
from torch_utils import misc
|
| 22 |
+
|
| 23 |
+
#----------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def training_loop(
|
| 26 |
+
run_dir = '.', # Output directory.
|
| 27 |
+
dataset_kwargs = {}, # Options for training set.
|
| 28 |
+
data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
|
| 29 |
+
network_kwargs = {}, # Options for model and preconditioning.
|
| 30 |
+
loss_kwargs = {}, # Options for loss function.
|
| 31 |
+
optimizer_kwargs = {}, # Options for optimizer.
|
| 32 |
+
augment_kwargs = None, # Options for augmentation pipeline, None = disable.
|
| 33 |
+
seed = 0, # Global random seed.
|
| 34 |
+
batch_size = 512, # Total batch size for one training iteration.
|
| 35 |
+
batch_gpu = None, # Limit batch size per GPU, None = no limit.
|
| 36 |
+
total_kimg = 200000, # Training duration, measured in thousands of training images.
|
| 37 |
+
ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
|
| 38 |
+
ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
|
| 39 |
+
lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
|
| 40 |
+
loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
|
| 41 |
+
kimg_per_tick = 50, # Interval of progress prints.
|
| 42 |
+
snapshot_ticks = 50, # How often to save network snapshots, None = disable.
|
| 43 |
+
state_dump_ticks = 500, # How often to dump training state, None = disable.
|
| 44 |
+
resume_pkl = None, # Start from the given network snapshot, None = random initialization.
|
| 45 |
+
resume_state_dump = None, # Start from the given training state, None = reset training state.
|
| 46 |
+
resume_kimg = 0, # Start from the given training progress.
|
| 47 |
+
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
|
| 48 |
+
device = torch.device('cuda'),
|
| 49 |
+
):
|
| 50 |
+
# Initialize.
|
| 51 |
+
start_time = time.time()
|
| 52 |
+
np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
|
| 53 |
+
torch.manual_seed(np.random.randint(1 << 31))
|
| 54 |
+
torch.backends.cudnn.benchmark = cudnn_benchmark
|
| 55 |
+
torch.backends.cudnn.allow_tf32 = False
|
| 56 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 57 |
+
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
| 58 |
+
|
| 59 |
+
# Select batch size per GPU.
|
| 60 |
+
batch_gpu_total = batch_size // dist.get_world_size()
|
| 61 |
+
if batch_gpu is None or batch_gpu > batch_gpu_total:
|
| 62 |
+
batch_gpu = batch_gpu_total
|
| 63 |
+
num_accumulation_rounds = batch_gpu_total // batch_gpu
|
| 64 |
+
assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
|
| 65 |
+
|
| 66 |
+
# Load dataset.
|
| 67 |
+
dist.print0('Loading dataset...')
|
| 68 |
+
dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
|
| 69 |
+
dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
|
| 70 |
+
dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
|
| 71 |
+
|
| 72 |
+
# Construct network.
|
| 73 |
+
dist.print0('Constructing network...')
|
| 74 |
+
interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
|
| 75 |
+
net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
|
| 76 |
+
net.train().requires_grad_(True).to(device)
|
| 77 |
+
if dist.get_rank() == 0:
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
|
| 80 |
+
sigma = torch.ones([batch_gpu], device=device)
|
| 81 |
+
labels = torch.zeros([batch_gpu, net.label_dim], device=device)
|
| 82 |
+
misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
|
| 83 |
+
|
| 84 |
+
# Setup optimizer.
|
| 85 |
+
dist.print0('Setting up optimizer...')
|
| 86 |
+
loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
|
| 87 |
+
optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
|
| 88 |
+
augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
|
| 89 |
+
ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
|
| 90 |
+
ema = copy.deepcopy(net).eval().requires_grad_(False)
|
| 91 |
+
|
| 92 |
+
# Resume training from previous snapshot.
|
| 93 |
+
if resume_pkl is not None:
|
| 94 |
+
dist.print0(f'Loading network weights from "{resume_pkl}"...')
|
| 95 |
+
if dist.get_rank() != 0:
|
| 96 |
+
torch.distributed.barrier() # rank 0 goes first
|
| 97 |
+
with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
|
| 98 |
+
data = pickle.load(f)
|
| 99 |
+
if dist.get_rank() == 0:
|
| 100 |
+
torch.distributed.barrier() # other ranks follow
|
| 101 |
+
misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
|
| 102 |
+
misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
|
| 103 |
+
del data # conserve memory
|
| 104 |
+
if resume_state_dump:
|
| 105 |
+
dist.print0(f'Loading training state from "{resume_state_dump}"...')
|
| 106 |
+
data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
|
| 107 |
+
misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
|
| 108 |
+
optimizer.load_state_dict(data['optimizer_state'])
|
| 109 |
+
del data # conserve memory

    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None
    while True:
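        # Each iteration of this loop processes one global batch; the bookkeeping further down
        # is grouped into 'ticks' of roughly kimg_per_tick thousand images.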

        # Accumulate gradients.
        optimizer.zero_grad(set_to_none=True)
        for round_idx in range(num_accumulation_rounds):
            with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                images, labels = next(dataset_iterator)
                images = images.to(device).to(torch.float32) / 127.5 - 1
                labels = labels.to(device)
                loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
                training_stats.report('Loss/loss', loss)
                loss.sum().mul(loss_scaling / batch_gpu_total).backward()
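        # Notes on the accumulation loop: uint8 images are rescaled from [0, 255] to [-1, 1];
        # ddp_sync() enables gradient all-reduce only on the last round, so ranks synchronize
        # once per global batch; dividing the summed loss by batch_gpu_total (combined with
        # DDP's averaging across ranks) yields a mean gradient over the full global batch.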

        # Update weights.
        for g in optimizer.param_groups:
            g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
        for param in net.parameters():
            if param.grad is not None:
                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
        optimizer.step()
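        # The learning rate ramps up linearly from 0 to its base value over the first
        # lr_rampup_kimg thousand images; non-finite gradient entries are clamped in place
        # before the step so a single bad batch cannot destroy the weights.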

        # Update EMA.
        ema_halflife_nimg = ema_halflife_kimg * 1000
        if ema_rampup_ratio is not None:
            ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
        ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
        for p_ema, p_net in zip(ema.parameters(), net.parameters()):
            p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
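        # ema_beta is chosen so the moving average has a half-life of ema_halflife_nimg images:
        # after processing that many images, the old EMA weights contribute a factor of 0.5.
        # lerp(p_ema, ema_beta) computes ema_beta * p_ema + (1 - ema_beta) * p_net.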

        # Perform maintenance tasks once per tick.
        cur_nimg += batch_size
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        dist.print0(' '.join(fields))
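        # report0() logs each value into training_stats (taking the value from rank 0 only),
        # so the figures printed in this status line also end up in stats.jsonl below.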

        # Check for abort.
        if (not done) and dist.should_stop():
            done = True
            dist.print0()
            dist.print0('Aborting...')

        # Save network snapshot.
        if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
            data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
            for key, value in data.items():
                if isinstance(value, torch.nn.Module):
                    value = copy.deepcopy(value).eval().requires_grad_(False)
                    misc.check_ddp_consistency(value)
                    data[key] = value.cpu()
                del value # conserve memory
            if dist.get_rank() == 0:
                with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
                    pickle.dump(data, f)
            del data # conserve memory
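        # Snapshots contain the EMA network (plus loss and augmentation config) as CPU copies
        # with gradients disabled; check_ddp_consistency() verifies that all ranks hold
        # identical parameters before rank 0 writes the pickle.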

        # Save full dump of the training state.
        if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
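        # Unlike the .pkl snapshot, the .pt dump stores the raw (non-EMA) network together with
        # the optimizer state; it is what 'resume_state_dump' above expects when continuing a run.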

        # Update logs.
        training_stats.default_collector.update()
        if dist.get_rank() == 0:
            if stats_jsonl is None:
                stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
            stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
            stats_jsonl.flush()
        dist.update_progress(cur_nimg // 1000, total_kimg)
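        # stats.jsonl grows by one JSON object per tick; the file is opened in append mode, so
        # resumed runs keep extending the same log, each entry carrying the collected statistics
        # plus a timestamp.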

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break
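        # maintenance_time measures how long the per-tick bookkeeping above (snapshots, state
        # dumps, logging) took; it is reported in the next tick's status line.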

    # Done.
    dist.print0()
    dist.print0('Exiting...')

#----------------------------------------------------------------------------