daidedou committed on
Commit df60d6b · 1 Parent(s): 1245229

forgot this

edm/Dockerfile ADDED
@@ -0,0 +1,18 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

FROM nvcr.io/nvidia/pytorch:22.10-py3

ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0

WORKDIR /workspace

RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
ENTRYPOINT ["/entry.sh"]
edm/LICENSE.txt ADDED
@@ -0,0 +1,439 @@
1
+ Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+
3
+ Attribution-NonCommercial-ShareAlike 4.0 International
4
+
5
+ =======================================================================
6
+
7
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
8
+ does not provide legal services or legal advice. Distribution of
9
+ Creative Commons public licenses does not create a lawyer-client or
10
+ other relationship. Creative Commons makes its licenses and related
11
+ information available on an "as-is" basis. Creative Commons gives no
12
+ warranties regarding its licenses, any material licensed under their
13
+ terms and conditions, or any related information. Creative Commons
14
+ disclaims all liability for damages resulting from their use to the
15
+ fullest extent possible.
16
+
17
+ Using Creative Commons Public Licenses
18
+
19
+ Creative Commons public licenses provide a standard set of terms and
20
+ conditions that creators and other rights holders may use to share
21
+ original works of authorship and other material subject to copyright
22
+ and certain other rights specified in the public license below. The
23
+ following considerations are for informational purposes only, are not
24
+ exhaustive, and do not form part of our licenses.
25
+
26
+ Considerations for licensors: Our public licenses are
27
+ intended for use by those authorized to give the public
28
+ permission to use material in ways otherwise restricted by
29
+ copyright and certain other rights. Our licenses are
30
+ irrevocable. Licensors should read and understand the terms
31
+ and conditions of the license they choose before applying it.
32
+ Licensors should also secure all rights necessary before
33
+ applying our licenses so that the public can reuse the
34
+ material as expected. Licensors should clearly mark any
35
+ material not subject to the license. This includes other CC-
36
+ licensed material, or material used under an exception or
37
+ limitation to copyright. More considerations for licensors:
38
+ wiki.creativecommons.org/Considerations_for_licensors
39
+
40
+ Considerations for the public: By using one of our public
41
+ licenses, a licensor grants the public permission to use the
42
+ licensed material under specified terms and conditions. If
43
+ the licensor's permission is not necessary for any reason--for
44
+ example, because of any applicable exception or limitation to
45
+ copyright--then that use is not regulated by the license. Our
46
+ licenses grant only permissions under copyright and certain
47
+ other rights that a licensor has authority to grant. Use of
48
+ the licensed material may still be restricted for other
49
+ reasons, including because others have copyright or other
50
+ rights in the material. A licensor may make special requests,
51
+ such as asking that all changes be marked or described.
52
+ Although not required by our licenses, you are encouraged to
53
+ respect those requests where reasonable. More considerations
54
+ for the public:
55
+ wiki.creativecommons.org/Considerations_for_licensees
56
+
57
+ =======================================================================
58
+
59
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
60
+ Public License
61
+
62
+ By exercising the Licensed Rights (defined below), You accept and agree
63
+ to be bound by the terms and conditions of this Creative Commons
64
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
65
+ ("Public License"). To the extent this Public License may be
66
+ interpreted as a contract, You are granted the Licensed Rights in
67
+ consideration of Your acceptance of these terms and conditions, and the
68
+ Licensor grants You such rights in consideration of benefits the
69
+ Licensor receives from making the Licensed Material available under
70
+ these terms and conditions.
71
+
72
+
73
+ Section 1 -- Definitions.
74
+
75
+ a. Adapted Material means material subject to Copyright and Similar
76
+ Rights that is derived from or based upon the Licensed Material
77
+ and in which the Licensed Material is translated, altered,
78
+ arranged, transformed, or otherwise modified in a manner requiring
79
+ permission under the Copyright and Similar Rights held by the
80
+ Licensor. For purposes of this Public License, where the Licensed
81
+ Material is a musical work, performance, or sound recording,
82
+ Adapted Material is always produced where the Licensed Material is
83
+ synched in timed relation with a moving image.
84
+
85
+ b. Adapter's License means the license You apply to Your Copyright
86
+ and Similar Rights in Your contributions to Adapted Material in
87
+ accordance with the terms and conditions of this Public License.
88
+
89
+ c. BY-NC-SA Compatible License means a license listed at
90
+ creativecommons.org/compatiblelicenses, approved by Creative
91
+ Commons as essentially the equivalent of this Public License.
92
+
93
+ d. Copyright and Similar Rights means copyright and/or similar rights
94
+ closely related to copyright including, without limitation,
95
+ performance, broadcast, sound recording, and Sui Generis Database
96
+ Rights, without regard to how the rights are labeled or
97
+ categorized. For purposes of this Public License, the rights
98
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
99
+ Rights.
100
+
101
+ e. Effective Technological Measures means those measures that, in the
102
+ absence of proper authority, may not be circumvented under laws
103
+ fulfilling obligations under Article 11 of the WIPO Copyright
104
+ Treaty adopted on December 20, 1996, and/or similar international
105
+ agreements.
106
+
107
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
108
+ any other exception or limitation to Copyright and Similar Rights
109
+ that applies to Your use of the Licensed Material.
110
+
111
+ g. License Elements means the license attributes listed in the name
112
+ of a Creative Commons Public License. The License Elements of this
113
+ Public License are Attribution, NonCommercial, and ShareAlike.
114
+
115
+ h. Licensed Material means the artistic or literary work, database,
116
+ or other material to which the Licensor applied this Public
117
+ License.
118
+
119
+ i. Licensed Rights means the rights granted to You subject to the
120
+ terms and conditions of this Public License, which are limited to
121
+ all Copyright and Similar Rights that apply to Your use of the
122
+ Licensed Material and that the Licensor has authority to license.
123
+
124
+ j. Licensor means the individual(s) or entity(ies) granting rights
125
+ under this Public License.
126
+
127
+ k. NonCommercial means not primarily intended for or directed towards
128
+ commercial advantage or monetary compensation. For purposes of
129
+ this Public License, the exchange of the Licensed Material for
130
+ other material subject to Copyright and Similar Rights by digital
131
+ file-sharing or similar means is NonCommercial provided there is
132
+ no payment of monetary compensation in connection with the
133
+ exchange.
134
+
135
+ l. Share means to provide material to the public by any means or
136
+ process that requires permission under the Licensed Rights, such
137
+ as reproduction, public display, public performance, distribution,
138
+ dissemination, communication, or importation, and to make material
139
+ available to the public including in ways that members of the
140
+ public may access the material from a place and at a time
141
+ individually chosen by them.
142
+
143
+ m. Sui Generis Database Rights means rights other than copyright
144
+ resulting from Directive 96/9/EC of the European Parliament and of
145
+ the Council of 11 March 1996 on the legal protection of databases,
146
+ as amended and/or succeeded, as well as other essentially
147
+ equivalent rights anywhere in the world.
148
+
149
+ n. You means the individual or entity exercising the Licensed Rights
150
+ under this Public License. Your has a corresponding meaning.
151
+
152
+
153
+ Section 2 -- Scope.
154
+
155
+ a. License grant.
156
+
157
+ 1. Subject to the terms and conditions of this Public License,
158
+ the Licensor hereby grants You a worldwide, royalty-free,
159
+ non-sublicensable, non-exclusive, irrevocable license to
160
+ exercise the Licensed Rights in the Licensed Material to:
161
+
162
+ a. reproduce and Share the Licensed Material, in whole or
163
+ in part, for NonCommercial purposes only; and
164
+
165
+ b. produce, reproduce, and Share Adapted Material for
166
+ NonCommercial purposes only.
167
+
168
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
169
+ Exceptions and Limitations apply to Your use, this Public
170
+ License does not apply, and You do not need to comply with
171
+ its terms and conditions.
172
+
173
+ 3. Term. The term of this Public License is specified in Section
174
+ 6(a).
175
+
176
+ 4. Media and formats; technical modifications allowed. The
177
+ Licensor authorizes You to exercise the Licensed Rights in
178
+ all media and formats whether now known or hereafter created,
179
+ and to make technical modifications necessary to do so. The
180
+ Licensor waives and/or agrees not to assert any right or
181
+ authority to forbid You from making technical modifications
182
+ necessary to exercise the Licensed Rights, including
183
+ technical modifications necessary to circumvent Effective
184
+ Technological Measures. For purposes of this Public License,
185
+ simply making modifications authorized by this Section 2(a)
186
+ (4) never produces Adapted Material.
187
+
188
+ 5. Downstream recipients.
189
+
190
+ a. Offer from the Licensor -- Licensed Material. Every
191
+ recipient of the Licensed Material automatically
192
+ receives an offer from the Licensor to exercise the
193
+ Licensed Rights under the terms and conditions of this
194
+ Public License.
195
+
196
+ b. Additional offer from the Licensor -- Adapted Material.
197
+ Every recipient of Adapted Material from You
198
+ automatically receives an offer from the Licensor to
199
+ exercise the Licensed Rights in the Adapted Material
200
+ under the conditions of the Adapter's License You apply.
201
+
202
+ c. No downstream restrictions. You may not offer or impose
203
+ any additional or different terms or conditions on, or
204
+ apply any Effective Technological Measures to, the
205
+ Licensed Material if doing so restricts exercise of the
206
+ Licensed Rights by any recipient of the Licensed
207
+ Material.
208
+
209
+ 6. No endorsement. Nothing in this Public License constitutes or
210
+ may be construed as permission to assert or imply that You
211
+ are, or that Your use of the Licensed Material is, connected
212
+ with, or sponsored, endorsed, or granted official status by,
213
+ the Licensor or others designated to receive attribution as
214
+ provided in Section 3(a)(1)(A)(i).
215
+
216
+ b. Other rights.
217
+
218
+ 1. Moral rights, such as the right of integrity, are not
219
+ licensed under this Public License, nor are publicity,
220
+ privacy, and/or other similar personality rights; however, to
221
+ the extent possible, the Licensor waives and/or agrees not to
222
+ assert any such rights held by the Licensor to the limited
223
+ extent necessary to allow You to exercise the Licensed
224
+ Rights, but not otherwise.
225
+
226
+ 2. Patent and trademark rights are not licensed under this
227
+ Public License.
228
+
229
+ 3. To the extent possible, the Licensor waives any right to
230
+ collect royalties from You for the exercise of the Licensed
231
+ Rights, whether directly or through a collecting society
232
+ under any voluntary or waivable statutory or compulsory
233
+ licensing scheme. In all other cases the Licensor expressly
234
+ reserves any right to collect such royalties, including when
235
+ the Licensed Material is used other than for NonCommercial
236
+ purposes.
237
+
238
+
239
+ Section 3 -- License Conditions.
240
+
241
+ Your exercise of the Licensed Rights is expressly made subject to the
242
+ following conditions.
243
+
244
+ a. Attribution.
245
+
246
+ 1. If You Share the Licensed Material (including in modified
247
+ form), You must:
248
+
249
+ a. retain the following if it is supplied by the Licensor
250
+ with the Licensed Material:
251
+
252
+ i. identification of the creator(s) of the Licensed
253
+ Material and any others designated to receive
254
+ attribution, in any reasonable manner requested by
255
+ the Licensor (including by pseudonym if
256
+ designated);
257
+
258
+ ii. a copyright notice;
259
+
260
+ iii. a notice that refers to this Public License;
261
+
262
+ iv. a notice that refers to the disclaimer of
263
+ warranties;
264
+
265
+ v. a URI or hyperlink to the Licensed Material to the
266
+ extent reasonably practicable;
267
+
268
+ b. indicate if You modified the Licensed Material and
269
+ retain an indication of any previous modifications; and
270
+
271
+ c. indicate the Licensed Material is licensed under this
272
+ Public License, and include the text of, or the URI or
273
+ hyperlink to, this Public License.
274
+
275
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
276
+ reasonable manner based on the medium, means, and context in
277
+ which You Share the Licensed Material. For example, it may be
278
+ reasonable to satisfy the conditions by providing a URI or
279
+ hyperlink to a resource that includes the required
280
+ information.
281
+ 3. If requested by the Licensor, You must remove any of the
282
+ information required by Section 3(a)(1)(A) to the extent
283
+ reasonably practicable.
284
+
285
+ b. ShareAlike.
286
+
287
+ In addition to the conditions in Section 3(a), if You Share
288
+ Adapted Material You produce, the following conditions also apply.
289
+
290
+ 1. The Adapter's License You apply must be a Creative Commons
291
+ license with the same License Elements, this version or
292
+ later, or a BY-NC-SA Compatible License.
293
+
294
+ 2. You must include the text of, or the URI or hyperlink to, the
295
+ Adapter's License You apply. You may satisfy this condition
296
+ in any reasonable manner based on the medium, means, and
297
+ context in which You Share Adapted Material.
298
+
299
+ 3. You may not offer or impose any additional or different terms
300
+ or conditions on, or apply any Effective Technological
301
+ Measures to, Adapted Material that restrict exercise of the
302
+ rights granted under the Adapter's License You apply.
303
+
304
+
305
+ Section 4 -- Sui Generis Database Rights.
306
+
307
+ Where the Licensed Rights include Sui Generis Database Rights that
308
+ apply to Your use of the Licensed Material:
309
+
310
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
311
+ to extract, reuse, reproduce, and Share all or a substantial
312
+ portion of the contents of the database for NonCommercial purposes
313
+ only;
314
+
315
+ b. if You include all or a substantial portion of the database
316
+ contents in a database in which You have Sui Generis Database
317
+ Rights, then the database in which You have Sui Generis Database
318
+ Rights (but not its individual contents) is Adapted Material,
319
+ including for purposes of Section 3(b); and
320
+
321
+ c. You must comply with the conditions in Section 3(a) if You Share
322
+ all or a substantial portion of the contents of the database.
323
+
324
+ For the avoidance of doubt, this Section 4 supplements and does not
325
+ replace Your obligations under this Public License where the Licensed
326
+ Rights include other Copyright and Similar Rights.
327
+
328
+
329
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
330
+
331
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
332
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
333
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
334
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
335
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
336
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
337
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
338
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
339
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
340
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
341
+
342
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
343
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
344
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
345
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
346
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
347
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
348
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
349
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
350
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
351
+
352
+ c. The disclaimer of warranties and limitation of liability provided
353
+ above shall be interpreted in a manner that, to the extent
354
+ possible, most closely approximates an absolute disclaimer and
355
+ waiver of all liability.
356
+
357
+
358
+ Section 6 -- Term and Termination.
359
+
360
+ a. This Public License applies for the term of the Copyright and
361
+ Similar Rights licensed here. However, if You fail to comply with
362
+ this Public License, then Your rights under this Public License
363
+ terminate automatically.
364
+
365
+ b. Where Your right to use the Licensed Material has terminated under
366
+ Section 6(a), it reinstates:
367
+
368
+ 1. automatically as of the date the violation is cured, provided
369
+ it is cured within 30 days of Your discovery of the
370
+ violation; or
371
+
372
+ 2. upon express reinstatement by the Licensor.
373
+
374
+ For the avoidance of doubt, this Section 6(b) does not affect any
375
+ right the Licensor may have to seek remedies for Your violations
376
+ of this Public License.
377
+
378
+ c. For the avoidance of doubt, the Licensor may also offer the
379
+ Licensed Material under separate terms or conditions or stop
380
+ distributing the Licensed Material at any time; however, doing so
381
+ will not terminate this Public License.
382
+
383
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
384
+ License.
385
+
386
+
387
+ Section 7 -- Other Terms and Conditions.
388
+
389
+ a. The Licensor shall not be bound by any additional or different
390
+ terms or conditions communicated by You unless expressly agreed.
391
+
392
+ b. Any arrangements, understandings, or agreements regarding the
393
+ Licensed Material not stated herein are separate from and
394
+ independent of the terms and conditions of this Public License.
395
+
396
+
397
+ Section 8 -- Interpretation.
398
+
399
+ a. For the avoidance of doubt, this Public License does not, and
400
+ shall not be interpreted to, reduce, limit, restrict, or impose
401
+ conditions on any use of the Licensed Material that could lawfully
402
+ be made without permission under this Public License.
403
+
404
+ b. To the extent possible, if any provision of this Public License is
405
+ deemed unenforceable, it shall be automatically reformed to the
406
+ minimum extent necessary to make it enforceable. If the provision
407
+ cannot be reformed, it shall be severed from this Public License
408
+ without affecting the enforceability of the remaining terms and
409
+ conditions.
410
+
411
+ c. No term or condition of this Public License will be waived and no
412
+ failure to comply consented to unless expressly agreed to by the
413
+ Licensor.
414
+
415
+ d. Nothing in this Public License constitutes or may be interpreted
416
+ as a limitation upon, or waiver of, any privileges and immunities
417
+ that apply to the Licensor or You, including from the legal
418
+ processes of any jurisdiction or authority.
419
+
420
+ =======================================================================
421
+
422
+ Creative Commons is not a party to its public
423
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
424
+ its public licenses to material it publishes and in those instances
425
+ will be considered the "Licensor." The text of the Creative Commons
426
+ public licenses is dedicated to the public domain under the CC0 Public
427
+ Domain Dedication. Except for the limited purpose of indicating that
428
+ material is shared under a Creative Commons public license or as
429
+ otherwise permitted by the Creative Commons policies published at
430
+ creativecommons.org/policies, Creative Commons does not authorize the
431
+ use of the trademark "Creative Commons" or any other trademark or logo
432
+ of Creative Commons without its prior written consent including,
433
+ without limitation, in connection with any unauthorized modifications
434
+ to any of its public licenses or any other arrangements,
435
+ understandings, or agreements concerning use of licensed material. For
436
+ the avoidance of doubt, this paragraph does not form part of the
437
+ public licenses.
438
+
439
+ Creative Commons may be contacted at creativecommons.org.
edm/README.md ADDED
@@ -0,0 +1,246 @@
## Elucidating the Design Space of Diffusion-Based Generative Models (EDM)<br><sub>Official PyTorch implementation of the NeurIPS 2022 paper</sub>

![Teaser image](./docs/teaser-1920x640.jpg)

**Elucidating the Design Space of Diffusion-Based Generative Models**<br>
Tero Karras, Miika Aittala, Timo Aila, Samuli Laine
<br>https://arxiv.org/abs/2206.00364<br>

Abstract: *We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*

For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)

## Requirements

* Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
* 1+ high-end NVIDIA GPU for sampling and 8+ GPUs for training. We have done all testing and development using V100 and A100 GPUs.
* 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch install instructions. A quick sanity check for this is sketched after the list.
* Python libraries: See [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
  - `conda env create -f environment.yml -n edm`
  - `conda activate edm`
* Docker users:
  - Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
  - Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.

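As a quick sanity check of the requirements above, the following snippet (an editorial sketch, not part of the repository) verifies the Python and PyTorch versions and confirms that a GPU is visible:

```.python
# Hedged sanity check for the requirements listed above; not part of the repository.
import sys
import torch

assert sys.version_info >= (3, 8), "64-bit Python 3.8 or later is required"
torch_version = tuple(int(x) for x in torch.__version__.split('+')[0].split('.')[:2])
assert torch_version >= (1, 12), "PyTorch 1.12.0 or later is required"
print("CUDA available:", torch.cuda.is_available())  # should be True on a correctly set up machine
print("Visible GPUs:", torch.cuda.device_count())    # 1+ for sampling, 8+ recommended for training
```
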
## Getting started

To reproduce the main results from our paper, simply run:

```.bash
python example.py
```

This is a minimal standalone script that loads the best pre-trained model for each dataset and generates a random 8x8 grid of images using the optimal sampler settings. Expected results:

| Dataset | Runtime | Reference image
| :------- | :------ | :--------------
| CIFAR-10 | ~6 sec | [`cifar10-32x32.png`](./docs/cifar10-32x32.png)
| FFHQ | ~28 sec | [`ffhq-64x64.png`](./docs/ffhq-64x64.png)
| AFHQv2 | ~28 sec | [`afhqv2-64x64.png`](./docs/afhqv2-64x64.png)
| ImageNet | ~5 min | [`imagenet-64x64.png`](./docs/imagenet-64x64.png)

The easiest way to explore different sampling strategies is to modify [`example.py`](./example.py) directly. You can also incorporate the pre-trained models and/or our proposed EDM sampler in your own code by simply copy-pasting the relevant bits. Note that the class definitions for the pre-trained models are stored within the pickles themselves and loaded automatically during unpickling via [`torch_utils.persistence`](./torch_utils/persistence.py). To use the models in external Python scripts, just make sure that `torch_utils` and `dnnlib` are accessible through `PYTHONPATH`.

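For illustration, a minimal sketch of loading one of the pre-trained pickles from an external script might look as follows. This is a hedged example: the `'ema'` dictionary key and the `dnnlib.util.open_url` helper are assumptions based on how the bundled scripts load networks; refer to [`example.py`](./example.py) and [`generate.py`](./generate.py) for the authoritative code.

```.python
# Hedged sketch: load a pre-trained EDM pickle from an external script.
# Assumes the repository root (containing torch_utils/ and dnnlib/) is on PYTHONPATH.
import pickle
import dnnlib  # provided by this repository

network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
with dnnlib.util.open_url(network_pkl) as f:
    net = pickle.load(f)['ema']   # the 'ema' key is an assumption; see generate.py
net = net.to('cuda').eval()       # the unpickled object is a regular torch.nn.Module
```
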
**Docker**: You can run the example script using Docker as follows:

```.bash
# Build the edm:latest image
docker build --tag edm:latest .

# Run the example.py script using Docker:
docker run --gpus all -it --rm --user $(id -u):$(id -g) \
    -v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \
    edm:latest \
    python example.py
```

Note: The Docker image requires NVIDIA driver release `r520` or later.

The `docker run` invocation may look daunting, so let's unpack its contents here:

- `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with the current user's UID/GID to avoid Docker writing files as root.
- ``-v `pwd`:/scratch --workdir /scratch``: mount the current working directory (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working directory.
- `-e HOME=/scratch`: specify where to cache temporary files. Note: if you want more fine-grained control, you can instead set `DNNLIB_CACHE_DIR` (for the pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations.

## Pre-trained models

We provide pre-trained models for our proposed training configuration (config F) as well as the baseline configuration (config A):

- [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/)
- [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/)

To generate a batch of images using a given model and sampler, run:

```.bash
# Generate 64 images and save them as out/*.png
python generate.py --outdir=out --seeds=0-63 --batch=64 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
```

Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the above command using `torchrun`:

```.bash
# Generate 1000 images using 2 GPUs
torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
```

The sampler settings can be controlled through command-line options; see [`python generate.py --help`](./docs/generate-help.txt) for more information. For best results, we recommend using the following settings for each dataset:

```.bash
# For CIFAR-10 at 32x32, use deterministic sampling with 18 steps (NFE = 35)
python generate.py --outdir=out --steps=18 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

# For FFHQ and AFHQv2 at 64x64, use deterministic sampling with 40 steps (NFE = 79)
python generate.py --outdir=out --steps=40 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl

# For ImageNet at 64x64, use stochastic sampling with 256 steps (NFE = 511)
python generate.py --outdir=out --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl
```

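The `--steps` option determines the sequence of noise levels that the sampler visits; with the second-order Heun method used here, N steps correspond to NFE = 2*N - 1 network evaluations, which is where the numbers 35, 79, and 511 above come from. As a hedged illustration, the time-step discretization described in the paper can be sketched as follows (the defaults sigma_min=0.002, sigma_max=80, and rho=7 are taken from the paper and should be treated as assumptions rather than the exact values used by `generate.py`):

```.python
# Hedged sketch of the EDM noise-level schedule; see generate.py for the authoritative code.
import numpy as np

def edm_sigma_steps(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    i = np.arange(num_steps)
    sigmas = (sigma_max ** (1 / rho)
              + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return np.append(sigmas, 0.0)  # the final step integrates down to sigma = 0

print(edm_sigma_steps(18))  # 18 noise levels from 80 down to 0.002, followed by 0
```
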
Besides our proposed EDM sampler, `generate.py` can also be used to reproduce the sampler ablations from Section 3 of our paper. For example:

```.bash
# Figure 2a, "Our reimplementation"
python generate.py --outdir=out --steps=512 --solver=euler --disc=vp --schedule=vp --scaling=vp \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl

# Figure 2a, "+ Heun & our {t_i}"
python generate.py --outdir=out --steps=128 --solver=heun --disc=edm --schedule=vp --scaling=vp \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl

# Figure 2a, "+ Our sigma(t) & s(t)"
python generate.py --outdir=out --steps=18 --solver=heun --disc=edm --schedule=linear --scaling=none \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
```

## Calculating FID

To compute Fr&eacute;chet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`:

```.bash
# Generate 50000 images and save them as fid-tmp/*/*.png
torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

# Calculate FID
torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \
    --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
```

Both of the above commands can be parallelized across multiple GPUs by adjusting `--nproc_per_node`. The second command typically takes 1-3 minutes in practice, but the first one can sometimes take several hours, depending on the configuration. See [`python fid.py --help`](./docs/fid-help.txt) for the full list of options.

Note that the numerical value of FID varies across different random seeds and is highly sensitive to the number of images. By default, `fid.py` will always use 50,000 generated images; providing fewer images will result in an error, whereas providing more will use a random subset. To reduce the effect of random variation, we recommend repeating the calculation multiple times with different seeds, e.g., `--seeds=0-49999`, `--seeds=50000-99999`, and `--seeds=100000-149999`. In our paper, we calculated each FID three times and reported the minimum.

Also note that it is important to compare the generated images against the same dataset that the model was originally trained with. To facilitate evaluation, we provide the exact reference statistics that correspond to our pre-trained models:

* [https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/)

For ImageNet, we provide two sets of reference statistics to enable apples-to-apples comparison: `imagenet-64x64.npz` should be used when evaluating the EDM model (`edm-imagenet-64x64-cond-adm.pkl`), whereas `imagenet-64x64-baseline.npz` should be used when evaluating the baseline model (`baseline-imagenet-64x64-cond-adm.pkl`); the latter was originally trained by Dhariwal and Nichol using slightly different training data.

You can compute the reference statistics for your own datasets as follows:

```.bash
python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
```

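Conceptually, FID fits a Gaussian to the Inception-v3 features of the generated images and another to the features of the reference set, and measures the Fr&eacute;chet distance between the two. A hedged sketch of that final distance computation, given precomputed statistics, is shown below (the `mu`/`sigma` variable names are illustrative and not necessarily the field names stored in the `.npz` files):

```.python
# Hedged sketch of the Frechet distance between two Gaussians fitted to Inception features:
# d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * (C1 C2)^(1/2))
import numpy as np
import scipy.linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    mean_term = np.square(mu1 - mu2).sum()
    cov_sqrt, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)  # matrix square root
    return float(mean_term + np.trace(sigma1 + sigma2 - 2 * cov_sqrt.real))
```
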
## Preparing datasets

Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information.

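To make the layout concrete, a hedged sketch of reading images and labels back from such an archive is shown below; the archive path is only an example, and the `dataset.json` structure (a `"labels"` list of `[filename, class index]` pairs) follows the description in `dataset_tool.py`.

```.python
# Hedged sketch: inspect a StyleGAN-style dataset ZIP (uncompressed PNGs + dataset.json).
import io
import json
import zipfile
import numpy as np
import PIL.Image

with zipfile.ZipFile('datasets/cifar10-32x32.zip') as z:           # example path
    labels = dict(json.loads(z.read('dataset.json'))['labels'])    # {filename: class index}
    fnames = [f for f in z.namelist() if f.endswith('.png')]
    img = np.array(PIL.Image.open(io.BytesIO(z.read(fnames[0]))))  # HWC uint8 array
    print(fnames[0], img.shape, labels.get(fnames[0]))
```
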
**CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:

```.bash
python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
    --dest=datasets/cifar10-32x32.zip
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
```

**FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at 64x64 resolution:

```.bash
python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
    --dest=datasets/ffhq-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
```

**AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert to ZIP archive at 64x64 resolution:

```.bash
python dataset_tool.py --source=downloads/afhqv2 \
    --dest=datasets/afhqv2-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz
```

**ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution:

```.bash
python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
    --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz
```

## Training new models

You can train new models using `train.py`. For example:

```.bash
# Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \
    --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
```

The above example uses the default batch size of 512 images (controlled by `--batch`), which is divided evenly among 8 GPUs (controlled by `--nproc_per_node`) to yield 64 images per GPU. Training large models may run out of GPU memory; the best way to avoid this is to limit the per-GPU batch size, e.g., `--batch-gpu=32`. This employs gradient accumulation to yield the same results as using full per-GPU batches. See [`python train.py --help`](./docs/train-help.txt) for the full list of options.

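For reference, the interplay of `--batch` and `--batch-gpu` amounts to standard gradient accumulation: each GPU processes its share of the logical batch in smaller rounds and accumulates gradients before taking a single optimizer step. A schematic sketch (an editorial illustration, not the actual training loop; see `train.py` for the real implementation) is:

```.python
# Hedged sketch of gradient accumulation with a per-GPU micro-batch of `batch_gpu` images.
def accumulation_step(net, loss_fn, optimizer, images, batch_gpu):
    # `images` holds this GPU's share of the logical batch (e.g., 64 images).
    optimizer.zero_grad(set_to_none=True)
    num_rounds = images.shape[0] // batch_gpu
    for round_idx in range(num_rounds):
        chunk = images[round_idx * batch_gpu : (round_idx + 1) * batch_gpu]
        loss = loss_fn(net, chunk) / num_rounds  # average over accumulation rounds
        loss.backward()                          # gradients accumulate across rounds
    optimizer.step()                             # one parameter update per logical batch
```
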
The results of each training run are saved to a newly created directory, for example `training-runs/00000-cifar10-cond-ddpmpp-edm-gpus8-batch64-fp32`. The training loop exports network snapshots (`network-snapshot-*.pkl`) and training states (`training-state-*.pt`) at regular intervals (controlled by `--snap` and `--dump`). The network snapshots can be used to generate images with `generate.py`, and the training states can be used to resume the training later on (`--resume`). Other useful information is recorded in `log.txt` and `stats.jsonl`. To monitor training convergence, we recommend looking at the training loss (`"Loss/loss"` in `stats.jsonl`) as well as periodically evaluating FID for `network-snapshot-*.pkl` using `generate.py` and `fid.py`.

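To pull the loss curve out of `stats.jsonl`, a small hedged sketch is shown below. The per-line JSON structure (one object per logged interval, with `"Loss/loss"` mapping to a dict that includes a `"mean"` field) is an assumption; inspect a real `stats.jsonl` from your run to confirm the exact schema.

```.python
# Hedged sketch: extract the mean training loss per logged interval from stats.jsonl.
import json

def read_loss_curve(path):
    losses = []
    with open(path) as f:
        for line in f:
            entry = json.loads(line)
            if 'Loss/loss' in entry:
                stat = entry['Loss/loss']
                losses.append(stat['mean'] if isinstance(stat, dict) else stat)
    return losses
```
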
The following table lists the exact training configurations that we used to obtain our pre-trained models:

| <sub>Model</sub> | <sub>GPUs</sub> | <sub>Time</sub> | <sub>Options</sub>
| :-- | :-- | :-- | :--
| <sub>cifar10&#8209;32x32&#8209;cond&#8209;vp</sub> | <sub>8xV100</sub> | <sub>~2&nbsp;days</sub> | <sub>`--cond=1 --arch=ddpmpp`</sub>
| <sub>cifar10&#8209;32x32&#8209;cond&#8209;ve</sub> | <sub>8xV100</sub> | <sub>~2&nbsp;days</sub> | <sub>`--cond=1 --arch=ncsnpp`</sub>
| <sub>cifar10&#8209;32x32&#8209;uncond&#8209;vp</sub> | <sub>8xV100</sub> | <sub>~2&nbsp;days</sub> | <sub>`--cond=0 --arch=ddpmpp`</sub>
| <sub>cifar10&#8209;32x32&#8209;uncond&#8209;ve</sub> | <sub>8xV100</sub> | <sub>~2&nbsp;days</sub> | <sub>`--cond=0 --arch=ncsnpp`</sub>
| <sub>ffhq&#8209;64x64&#8209;uncond&#8209;vp</sub> | <sub>8xV100</sub> | <sub>~4&nbsp;days</sub> | <sub>`--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`</sub>
| <sub>ffhq&#8209;64x64&#8209;uncond&#8209;ve</sub> | <sub>8xV100</sub> | <sub>~4&nbsp;days</sub> | <sub>`--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`</sub>
| <sub>afhqv2&#8209;64x64&#8209;uncond&#8209;vp</sub> | <sub>8xV100</sub> | <sub>~4&nbsp;days</sub> | <sub>`--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`</sub>
| <sub>afhqv2&#8209;64x64&#8209;uncond&#8209;ve</sub> | <sub>8xV100</sub> | <sub>~4&nbsp;days</sub> | <sub>`--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`</sub>
| <sub>imagenet&#8209;64x64&#8209;cond&#8209;adm</sub> | <sub>32xA100</sub> | <sub>~13&nbsp;days</sub> | <sub>`--cond=1 --arch=adm --duration=2500 --batch=4096 --lr=1e-4 --ema=50 --dropout=0.10 --augment=0 --fp16=1 --ls=100 --tick=200`</sub>

For ImageNet-64, we ran the training on four NVIDIA DGX A100 nodes, each containing 8 Ampere GPUs with 80 GB of memory. To reduce the GPU memory requirements, we recommend either training the model with more GPUs or limiting the per-GPU batch size with `--batch-gpu`. To set up multi-node training, please consult the [torchrun documentation](https://pytorch.org/docs/stable/elastic/run.html).

## License

Copyright &copy; 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).

`baseline-cifar10-32x32-uncond-vp.pkl` and `baseline-cifar10-32x32-uncond-ve.pkl` are derived from the [pre-trained models](https://github.com/yang-song/score_sde_pytorch) by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. The models were originally shared under the [Apache 2.0 license](https://github.com/yang-song/score_sde_pytorch/blob/main/LICENSE).

`baseline-imagenet-64x64-cond-adm.pkl` is derived from the [pre-trained model](https://github.com/openai/guided-diffusion) by Prafulla Dhariwal and Alex Nichol. The model was originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).

`imagenet-64x64-baseline.npz` is derived from the [precomputed reference statistics](https://github.com/openai/guided-diffusion/tree/main/evaluations) by Prafulla Dhariwal and Alex Nichol. The statistics were originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).

## Citation

```
@inproceedings{Karras2022edm,
  author    = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
  title     = {Elucidating the Design Space of Diffusion-Based Generative Models},
  booktitle = {Proc. NeurIPS},
  year      = {2022}
}
```

## Development

This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.

## Acknowledgments

We thank Jaakko Lehtinen, Ming-Yu Liu, Tuomas Kynk&auml;&auml;nniemi, Axel Sauer, Arash Vahdat, and Janne Hellsten for discussions and comments, and Tero Kuosmanen, Samuel Klenberg, and Janne Hellsten for maintaining our compute infrastructure.
edm/dataset_tool.py ADDED
@@ -0,0 +1,440 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Tool for creating ZIP/PNG based datasets."""
9
+
10
+ import functools
11
+ import gzip
12
+ import io
13
+ import json
14
+ import os
15
+ import pickle
16
+ import re
17
+ import sys
18
+ import tarfile
19
+ import zipfile
20
+ from pathlib import Path
21
+ from typing import Callable, Optional, Tuple, Union
22
+ import click
23
+ import numpy as np
24
+ import PIL.Image
25
+ from tqdm import tqdm
26
+
27
+ #----------------------------------------------------------------------------
28
+ # Parse a 'M,N' or 'MxN' integer tuple.
29
+ # Example: '4x2' returns (4,2)
30
+
31
+ def parse_tuple(s: str) -> Tuple[int, int]:
32
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
33
+ if m:
34
+ return int(m.group(1)), int(m.group(2))
35
+ raise click.ClickException(f'cannot parse tuple {s}')
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ def maybe_min(a: int, b: Optional[int]) -> int:
40
+ if b is not None:
41
+ return min(a, b)
42
+ return a
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def file_ext(name: Union[str, Path]) -> str:
47
+ return str(name).split('.')[-1]
48
+
49
+ #----------------------------------------------------------------------------
50
+
51
+ def is_image_ext(fname: Union[str, Path]) -> bool:
52
+ ext = file_ext(fname).lower()
53
+ return f'.{ext}' in PIL.Image.EXTENSION
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def open_image_folder(source_dir, *, max_images: Optional[int]):
58
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
59
+ arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
60
+ max_idx = maybe_min(len(input_images), max_images)
61
+
62
+ # Load labels.
63
+ labels = dict()
64
+ meta_fname = os.path.join(source_dir, 'dataset.json')
65
+ if os.path.isfile(meta_fname):
66
+ with open(meta_fname, 'r') as file:
67
+ data = json.load(file)['labels']
68
+ if data is not None:
69
+ labels = {x[0]: x[1] for x in data}
70
+
71
+ # No labels available => determine from top-level directory names.
72
+ if len(labels) == 0:
73
+ toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
74
+ toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
75
+ if len(toplevel_indices) > 1:
76
+ labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
77
+
78
+ def iterate_images():
79
+ for idx, fname in enumerate(input_images):
80
+ img = np.array(PIL.Image.open(fname))
81
+ yield dict(img=img, label=labels.get(arch_fnames.get(fname)))
82
+ if idx >= max_idx - 1:
83
+ break
84
+ return max_idx, iterate_images()
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def open_image_zip(source, *, max_images: Optional[int]):
89
+ with zipfile.ZipFile(source, mode='r') as z:
90
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
91
+ max_idx = maybe_min(len(input_images), max_images)
92
+
93
+ # Load labels.
94
+ labels = dict()
95
+ if 'dataset.json' in z.namelist():
96
+ with z.open('dataset.json', 'r') as file:
97
+ data = json.load(file)['labels']
98
+ if data is not None:
99
+ labels = {x[0]: x[1] for x in data}
100
+
101
+ def iterate_images():
102
+ with zipfile.ZipFile(source, mode='r') as z:
103
+ for idx, fname in enumerate(input_images):
104
+ with z.open(fname, 'r') as file:
105
+ img = np.array(PIL.Image.open(file))
106
+ yield dict(img=img, label=labels.get(fname))
107
+ if idx >= max_idx - 1:
108
+ break
109
+ return max_idx, iterate_images()
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
114
+ import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python
115
+ import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb
116
+
117
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
119
+
120
+ def iterate_images():
121
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
122
+ for idx, (_key, value) in enumerate(txn.cursor()):
123
+ try:
124
+ try:
125
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
126
+ if img is None:
127
+ raise IOError('cv2.imdecode failed')
128
+ img = img[:, :, ::-1] # BGR => RGB
129
+ except IOError:
130
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
131
+ yield dict(img=img, label=None)
132
+ if idx >= max_idx - 1:
133
+ break
134
+ except:
135
+ print(sys.exc_info()[1])
136
+
137
+ return max_idx, iterate_images()
138
+
139
+ #----------------------------------------------------------------------------
140
+
141
+ def open_cifar10(tarball: str, *, max_images: Optional[int]):
142
+ images = []
143
+ labels = []
144
+
145
+ with tarfile.open(tarball, 'r:gz') as tar:
146
+ for batch in range(1, 6):
147
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
148
+ with tar.extractfile(member) as file:
149
+ data = pickle.load(file, encoding='latin1')
150
+ images.append(data['data'].reshape(-1, 3, 32, 32))
151
+ labels.append(data['labels'])
152
+
153
+ images = np.concatenate(images)
154
+ labels = np.concatenate(labels)
155
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
156
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
157
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
158
+ assert np.min(images) == 0 and np.max(images) == 255
159
+ assert np.min(labels) == 0 and np.max(labels) == 9
160
+
161
+ max_idx = maybe_min(len(images), max_images)
162
+
163
+ def iterate_images():
164
+ for idx, img in enumerate(images):
165
+ yield dict(img=img, label=int(labels[idx]))
166
+ if idx >= max_idx - 1:
167
+ break
168
+
169
+ return max_idx, iterate_images()
170
+
171
+ #----------------------------------------------------------------------------
172
+
173
+ def open_mnist(images_gz: str, *, max_images: Optional[int]):
174
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
175
+ assert labels_gz != images_gz
176
+ images = []
177
+ labels = []
178
+
179
+ with gzip.open(images_gz, 'rb') as f:
180
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
181
+ with gzip.open(labels_gz, 'rb') as f:
182
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
183
+
184
+ images = images.reshape(-1, 28, 28)
185
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
186
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
187
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
188
+ assert np.min(images) == 0 and np.max(images) == 255
189
+ assert np.min(labels) == 0 and np.max(labels) == 9
190
+
191
+ max_idx = maybe_min(len(images), max_images)
192
+
193
+ def iterate_images():
194
+ for idx, img in enumerate(images):
195
+ yield dict(img=img, label=int(labels[idx]))
196
+ if idx >= max_idx - 1:
197
+ break
198
+
199
+ return max_idx, iterate_images()
200
+
201
+ #----------------------------------------------------------------------------
202
+
203
+ def make_transform(
204
+ transform: Optional[str],
205
+ output_width: Optional[int],
206
+ output_height: Optional[int]
207
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
208
+ def scale(width, height, img):
209
+ w = img.shape[1]
210
+ h = img.shape[0]
211
+ if width == w and height == h:
212
+ return img
213
+ img = PIL.Image.fromarray(img)
214
+ ww = width if width is not None else w
215
+ hh = height if height is not None else h
216
+ img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
217
+ return np.array(img)
218
+
219
+ def center_crop(width, height, img):
220
+ crop = np.min(img.shape[:2])
221
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
222
+ if img.ndim == 2:
223
+ img = img[:, :, np.newaxis].repeat(3, axis=2)
224
+ img = PIL.Image.fromarray(img, 'RGB')
225
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
226
+ return np.array(img)
227
+
228
+ def center_crop_wide(width, height, img):
229
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
230
+ if img.shape[1] < width or ch < height:
231
+ return None
232
+
233
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
234
+ if img.ndim == 2:
235
+ img = img[:, :, np.newaxis].repeat(3, axis=2)
236
+ img = PIL.Image.fromarray(img, 'RGB')
237
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
238
+ img = np.array(img)
239
+
240
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
241
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
242
+ return canvas
243
+
244
+ if transform is None:
245
+ return functools.partial(scale, output_width, output_height)
246
+ if transform == 'center-crop':
247
+ if output_width is None or output_height is None:
248
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
249
+ return functools.partial(center_crop, output_width, output_height)
250
+ if transform == 'center-crop-wide':
251
+ if output_width is None or output_height is None:
252
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
253
+ return functools.partial(center_crop_wide, output_width, output_height)
254
+ assert False, 'unknown transform'
255
+
256
+ #----------------------------------------------------------------------------
257
+
258
+ def open_dataset(source, *, max_images: Optional[int]):
259
+ if os.path.isdir(source):
260
+ if source.rstrip('/').endswith('_lmdb'):
261
+ return open_lmdb(source, max_images=max_images)
262
+ else:
263
+ return open_image_folder(source, max_images=max_images)
264
+ elif os.path.isfile(source):
265
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
266
+ return open_cifar10(source, max_images=max_images)
267
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
268
+ return open_mnist(source, max_images=max_images)
269
+ elif file_ext(source) == 'zip':
270
+ return open_image_zip(source, max_images=max_images)
271
+ else:
272
+ assert False, 'unknown archive type'
273
+ else:
274
+ raise click.ClickException(f'Missing input file or directory: {source}')
275
+
276
+ #----------------------------------------------------------------------------
277
+
278
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
279
+ dest_ext = file_ext(dest)
280
+
281
+ if dest_ext == 'zip':
282
+ if os.path.dirname(dest) != '':
283
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
284
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
285
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
286
+ zf.writestr(fname, data)
287
+ return '', zip_write_bytes, zf.close
288
+ else:
289
+ # If the output folder already exists, check that it is
290
+ # empty.
291
+ #
292
+ # Note: creating the output directory is not strictly
293
+ # necessary as folder_write_bytes() also mkdirs, but it's better
294
+ # to give an error message earlier in case the dest folder
295
+ # somehow cannot be created.
296
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
297
+ raise click.ClickException('--dest folder must be empty')
298
+ os.makedirs(dest, exist_ok=True)
299
+
300
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
301
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
302
+ with open(fname, 'wb') as fout:
303
+ if isinstance(data, str):
304
+ data = data.encode('utf8')
305
+ fout.write(data)
306
+ return dest, folder_write_bytes, lambda: None
307
+
308
+ #----------------------------------------------------------------------------
309
+
310
+ @click.command()
311
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
312
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
313
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
314
+ @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
315
+ @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
316
+
317
+ def main(
318
+ source: str,
319
+ dest: str,
320
+ max_images: Optional[int],
321
+ transform: Optional[str],
322
+ resolution: Optional[Tuple[int, int]]
323
+ ):
324
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
325
+
326
+ The input dataset format is guessed from the --source argument:
327
+
328
+ \b
329
+ --source *_lmdb/ Load LSUN dataset
330
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
331
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
332
+ --source path/ Recursively load all images from path/
333
+ --source dataset.zip Recursively load all images from dataset.zip
334
+
335
+ Specifying the output format and path:
336
+
337
+ \b
338
+ --dest /path/to/dir Save output files under /path/to/dir
339
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
340
+
341
+ The output dataset format can be either an image folder or an uncompressed zip archive.
342
+ Zip archives make it easier to move datasets around file servers and clusters, and may
343
+ offer better training performance on network file systems.
344
+
345
+ Images within the dataset archive will be stored as uncompressed PNG.
346
+ Uncompressed PNGs can be efficiently decoded in the training loop.
347
+
348
+ Class labels are stored in a file called 'dataset.json' that is stored at the
349
+ dataset root folder. This file has the following structure:
350
+
351
+ \b
352
+ {
353
+ "labels": [
354
+ ["00000/img00000000.png",6],
355
+ ["00000/img00000001.png",9],
356
+ ... repeated for every image in the dataset
357
+ ["00049/img00049999.png",1]
358
+ ]
359
+ }
360
+
361
+ If the 'dataset.json' file cannot be found, class labels are determined from
362
+ top-level directory names.
363
+
364
+ Image scale/crop and resolution requirements:
365
+
366
+ Output images must be square-shaped and they must all have the same power-of-two
367
+ dimensions.
368
+
369
+ To scale arbitrary input images to a specific width and height, use the
370
+ --resolution option. Output resolution will be either the original
371
+ input resolution (if --resolution was not specified) or the one specified with the
372
+ --resolution option.
373
+
374
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
375
+ center crop transform on the input image. These options should be used with the
376
+ --resolution option. For example:
377
+
378
+ \b
379
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
380
+ --transform=center-crop-wide --resolution=512x384
381
+ """
382
+
383
+ PIL.Image.init()
384
+
385
+ if dest == '':
386
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
387
+
388
+ num_files, input_iter = open_dataset(source, max_images=max_images)
389
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
390
+
391
+ if resolution is None: resolution = (None, None)
392
+ transform_image = make_transform(transform, *resolution)
393
+
394
+ dataset_attrs = None
395
+
396
+ labels = []
397
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
398
+ idx_str = f'{idx:08d}'
399
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
400
+
401
+ # Apply crop and resize.
402
+ img = transform_image(image['img'])
403
+ if img is None:
404
+ continue
405
+
406
+ # Error check to require uniform image attributes across
407
+ # the whole dataset.
408
+ channels = img.shape[2] if img.ndim == 3 else 1
409
+ cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
410
+ if dataset_attrs is None:
411
+ dataset_attrs = cur_image_attrs
412
+ width = dataset_attrs['width']
413
+ height = dataset_attrs['height']
414
+ if width != height:
415
+ raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
416
+ if dataset_attrs['channels'] not in [1, 3]:
417
+ raise click.ClickException('Input images must be stored as RGB or grayscale')
418
+ if width != 2 ** int(np.floor(np.log2(width))):
419
+ raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
420
+ elif dataset_attrs != cur_image_attrs:
421
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
422
+ raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
423
+
424
+ # Save the image as an uncompressed PNG.
425
+ img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
426
+ image_bits = io.BytesIO()
427
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
428
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
429
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
430
+
431
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
432
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
433
+ close_dest()
434
+
435
+ #----------------------------------------------------------------------------
436
+
437
+ if __name__ == "__main__":
438
+ main()
439
+
440
+ #----------------------------------------------------------------------------
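
As a quick sanity check, the 'dataset.json' layout documented in the docstring above can be inspected with a few lines of standard-library Python. This is a hypothetical sketch, not part of this commit; the archive path is a placeholder for whatever --dest was used.

    import json
    import zipfile

    archive_path = '/tmp/lsun_cat.zip'  # placeholder --dest produced by dataset_tool.py
    with zipfile.ZipFile(archive_path) as zf:
        with zf.open('dataset.json') as f:
            labels = json.load(f)['labels']  # None when the dataset is unlabeled
    if labels is not None:
        for fname, label in labels[:5]:
            print(fname, label)  # e.g. 00000/img00000000.png 6
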
edm/dnnlib/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ from .util import EasyDict, make_cache_dir_path
edm/dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Miscellaneous utility classes and functions."""
9
+
10
+ import ctypes
11
+ import fnmatch
12
+ import importlib
13
+ import inspect
14
+ import numpy as np
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import types
19
+ import io
20
+ import pickle
21
+ import re
22
+ import requests
23
+ import html
24
+ import hashlib
25
+ import glob
26
+ import tempfile
27
+ import urllib
28
+ import urllib.request
29
+ import uuid
30
+
31
+ from distutils.util import strtobool
32
+ from typing import Any, List, Tuple, Union, Optional
33
+
34
+
35
+ # Util classes
36
+ # ------------------------------------------------------------------------------------------
37
+
38
+
39
+ class EasyDict(dict):
40
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41
+
42
+ def __getattr__(self, name: str) -> Any:
43
+ try:
44
+ return self[name]
45
+ except KeyError:
46
+ raise AttributeError(name)
47
+
48
+ def __setattr__(self, name: str, value: Any) -> None:
49
+ self[name] = value
50
+
51
+ def __delattr__(self, name: str) -> None:
52
+ del self[name]
53
+
54
+
55
+ class Logger(object):
56
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57
+
58
+ def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59
+ self.file = None
60
+
61
+ if file_name is not None:
62
+ self.file = open(file_name, file_mode)
63
+
64
+ self.should_flush = should_flush
65
+ self.stdout = sys.stdout
66
+ self.stderr = sys.stderr
67
+
68
+ sys.stdout = self
69
+ sys.stderr = self
70
+
71
+ def __enter__(self) -> "Logger":
72
+ return self
73
+
74
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75
+ self.close()
76
+
77
+ def write(self, text: Union[str, bytes]) -> None:
78
+ """Write text to stdout (and a file) and optionally flush."""
79
+ if isinstance(text, bytes):
80
+ text = text.decode()
81
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82
+ return
83
+
84
+ if self.file is not None:
85
+ self.file.write(text)
86
+
87
+ self.stdout.write(text)
88
+
89
+ if self.should_flush:
90
+ self.flush()
91
+
92
+ def flush(self) -> None:
93
+ """Flush written text to both stdout and a file, if open."""
94
+ if self.file is not None:
95
+ self.file.flush()
96
+
97
+ self.stdout.flush()
98
+
99
+ def close(self) -> None:
100
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
101
+ self.flush()
102
+
103
+ # if using multiple loggers, prevent closing in wrong order
104
+ if sys.stdout is self:
105
+ sys.stdout = self.stdout
106
+ if sys.stderr is self:
107
+ sys.stderr = self.stderr
108
+
109
+ if self.file is not None:
110
+ self.file.close()
111
+ self.file = None
112
+
113
+
114
+ # Cache directories
115
+ # ------------------------------------------------------------------------------------------
116
+
117
+ _dnnlib_cache_dir = None
118
+
119
+ def set_cache_dir(path: str) -> None:
120
+ global _dnnlib_cache_dir
121
+ _dnnlib_cache_dir = path
122
+
123
+ def make_cache_dir_path(*paths: str) -> str:
124
+ if _dnnlib_cache_dir is not None:
125
+ return os.path.join(_dnnlib_cache_dir, *paths)
126
+ if 'DNNLIB_CACHE_DIR' in os.environ:
127
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128
+ if 'HOME' in os.environ:
129
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130
+ if 'USERPROFILE' in os.environ:
131
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133
+
134
+ # Small util functions
135
+ # ------------------------------------------------------------------------------------------
136
+
137
+
138
+ def format_time(seconds: Union[int, float]) -> str:
139
+ """Convert a number of seconds into a human-readable string with days, hours, minutes, and seconds."""
140
+ s = int(np.rint(seconds))
141
+
142
+ if s < 60:
143
+ return "{0}s".format(s)
144
+ elif s < 60 * 60:
145
+ return "{0}m {1:02}s".format(s // 60, s % 60)
146
+ elif s < 24 * 60 * 60:
147
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148
+ else:
149
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150
+
151
+
152
+ def format_time_brief(seconds: Union[int, float]) -> str:
153
+ """Convert a number of seconds into a human-readable string with days, hours, minutes, and seconds."""
154
+ s = int(np.rint(seconds))
155
+
156
+ if s < 60:
157
+ return "{0}s".format(s)
158
+ elif s < 60 * 60:
159
+ return "{0}m {1:02}s".format(s // 60, s % 60)
160
+ elif s < 24 * 60 * 60:
161
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162
+ else:
163
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164
+
165
+
166
+ def ask_yes_no(question: str) -> bool:
167
+ """Ask the user the question until the user inputs a valid answer."""
168
+ while True:
169
+ try:
170
+ print("{0} [y/n]".format(question))
171
+ return strtobool(input().lower())
172
+ except ValueError:
173
+ pass
174
+
175
+
176
+ def tuple_product(t: Tuple) -> Any:
177
+ """Calculate the product of the tuple elements."""
178
+ result = 1
179
+
180
+ for v in t:
181
+ result *= v
182
+
183
+ return result
184
+
185
+
186
+ _str_to_ctype = {
187
+ "uint8": ctypes.c_ubyte,
188
+ "uint16": ctypes.c_uint16,
189
+ "uint32": ctypes.c_uint32,
190
+ "uint64": ctypes.c_uint64,
191
+ "int8": ctypes.c_byte,
192
+ "int16": ctypes.c_int16,
193
+ "int32": ctypes.c_int32,
194
+ "int64": ctypes.c_int64,
195
+ "float32": ctypes.c_float,
196
+ "float64": ctypes.c_double
197
+ }
198
+
199
+
200
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202
+ type_str = None
203
+
204
+ if isinstance(type_obj, str):
205
+ type_str = type_obj
206
+ elif hasattr(type_obj, "__name__"):
207
+ type_str = type_obj.__name__
208
+ elif hasattr(type_obj, "name"):
209
+ type_str = type_obj.name
210
+ else:
211
+ raise RuntimeError("Cannot infer type name from input")
212
+
213
+ assert type_str in _str_to_ctype.keys()
214
+
215
+ my_dtype = np.dtype(type_str)
216
+ my_ctype = _str_to_ctype[type_str]
217
+
218
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219
+
220
+ return my_dtype, my_ctype
221
+
222
+
223
+ def is_pickleable(obj: Any) -> bool:
224
+ try:
225
+ with io.BytesIO() as stream:
226
+ pickle.dump(obj, stream)
227
+ return True
228
+ except:
229
+ return False
230
+
231
+
232
+ # Functionality to import modules/objects by name, and call functions by name
233
+ # ------------------------------------------------------------------------------------------
234
+
235
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236
+ """Searches for the module that defines the Python object referred to by the given name.
237
+ Returns the module and the object name (original name with module part removed)."""
238
+
239
+ # allow convenience shorthands, substitute them by full names
240
+ obj_name = re.sub("^np.", "numpy.", obj_name)
241
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242
+
243
+ # list alternatives for (module_name, local_obj_name)
244
+ parts = obj_name.split(".")
245
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246
+
247
+ # try each alternative in turn
248
+ for module_name, local_obj_name in name_pairs:
249
+ try:
250
+ module = importlib.import_module(module_name) # may raise ImportError
251
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
252
+ return module, local_obj_name
253
+ except:
254
+ pass
255
+
256
+ # maybe some of the modules themselves contain errors?
257
+ for module_name, _local_obj_name in name_pairs:
258
+ try:
259
+ importlib.import_module(module_name) # may raise ImportError
260
+ except ImportError:
261
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262
+ raise
263
+
264
+ # maybe the requested attribute is missing?
265
+ for module_name, local_obj_name in name_pairs:
266
+ try:
267
+ module = importlib.import_module(module_name) # may raise ImportError
268
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
269
+ except ImportError:
270
+ pass
271
+
272
+ # we are out of luck, but we have no idea why
273
+ raise ImportError(obj_name)
274
+
275
+
276
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277
+ """Traverses the object name and returns the last (rightmost) python object."""
278
+ if obj_name == '':
279
+ return module
280
+ obj = module
281
+ for part in obj_name.split("."):
282
+ obj = getattr(obj, part)
283
+ return obj
284
+
285
+
286
+ def get_obj_by_name(name: str) -> Any:
287
+ """Finds the python object with the given name."""
288
+ module, obj_name = get_module_from_obj_name(name)
289
+ return get_obj_from_module(module, obj_name)
290
+
291
+
292
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293
+ """Finds the python object with the given name and calls it as a function."""
294
+ assert func_name is not None
295
+ func_obj = get_obj_by_name(func_name)
296
+ assert callable(func_obj)
297
+ return func_obj(*args, **kwargs)
298
+
299
+
300
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301
+ """Finds the python class with the given name and constructs it with the given arguments."""
302
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
303
+
304
+
305
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
306
+ """Get the directory path of the module containing the given object name."""
307
+ module, _ = get_module_from_obj_name(obj_name)
308
+ return os.path.dirname(inspect.getfile(module))
309
+
310
+
311
+ def is_top_level_function(obj: Any) -> bool:
312
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
313
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
314
+
315
+
316
+ def get_top_level_function_name(obj: Any) -> str:
317
+ """Return the fully-qualified name of a top-level function."""
318
+ assert is_top_level_function(obj)
319
+ module = obj.__module__
320
+ if module == '__main__':
321
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
322
+ return module + "." + obj.__name__
323
+
324
+
325
+ # File system helpers
326
+ # ------------------------------------------------------------------------------------------
327
+
328
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
329
+ """List all files recursively in a given directory while ignoring given file and directory names.
330
+ Returns list of tuples containing both absolute and relative paths."""
331
+ assert os.path.isdir(dir_path)
332
+ base_name = os.path.basename(os.path.normpath(dir_path))
333
+
334
+ if ignores is None:
335
+ ignores = []
336
+
337
+ result = []
338
+
339
+ for root, dirs, files in os.walk(dir_path, topdown=True):
340
+ for ignore_ in ignores:
341
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
342
+
343
+ # dirs need to be edited in-place
344
+ for d in dirs_to_remove:
345
+ dirs.remove(d)
346
+
347
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
348
+
349
+ absolute_paths = [os.path.join(root, f) for f in files]
350
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
351
+
352
+ if add_base_to_relative:
353
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
354
+
355
+ assert len(absolute_paths) == len(relative_paths)
356
+ result += zip(absolute_paths, relative_paths)
357
+
358
+ return result
359
+
360
+
361
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
362
+ """Takes in a list of tuples of (src, dst) paths and copies files.
363
+ Will create all necessary directories."""
364
+ for file in files:
365
+ target_dir_name = os.path.dirname(file[1])
366
+
367
+ # will create all intermediate-level directories
368
+ if not os.path.exists(target_dir_name):
369
+ os.makedirs(target_dir_name)
370
+
371
+ shutil.copyfile(file[0], file[1])
372
+
373
+
374
+ # URL helpers
375
+ # ------------------------------------------------------------------------------------------
376
+
377
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
378
+ """Determine whether the given object is a valid URL string."""
379
+ if not isinstance(obj, str) or not "://" in obj:
380
+ return False
381
+ if allow_file_urls and obj.startswith('file://'):
382
+ return True
383
+ try:
384
+ res = requests.compat.urlparse(obj)
385
+ if not res.scheme or not res.netloc or not "." in res.netloc:
386
+ return False
387
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
388
+ if not res.scheme or not res.netloc or not "." in res.netloc:
389
+ return False
390
+ except:
391
+ return False
392
+ return True
393
+
394
+
395
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
396
+ """Download the given URL and return a binary-mode file object to access the data."""
397
+ assert num_attempts >= 1
398
+ assert not (return_filename and (not cache))
399
+
400
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
401
+ if not re.match('^[a-z]+://', url):
402
+ return url if return_filename else open(url, "rb")
403
+
404
+ # Handle file URLs. This code handles unusual file:// patterns that
405
+ # arise on Windows:
406
+ #
407
+ # file:///c:/foo.txt
408
+ #
409
+ # which would translate to a local '/c:/foo.txt' filename that's
410
+ # invalid. Drop the forward slash for such pathnames.
411
+ #
412
+ # If you touch this code path, you should test it on both Linux and
413
+ # Windows.
414
+ #
415
+ # Some internet resources suggest using urllib.request.url2pathname() but
416
+ # that converts forward slashes to backslashes and this causes
417
+ # its own set of problems.
418
+ if url.startswith('file://'):
419
+ filename = urllib.parse.urlparse(url).path
420
+ if re.match(r'^/[a-zA-Z]:', filename):
421
+ filename = filename[1:]
422
+ return filename if return_filename else open(filename, "rb")
423
+
424
+ assert is_url(url)
425
+
426
+ # Lookup from cache.
427
+ if cache_dir is None:
428
+ cache_dir = make_cache_dir_path('downloads')
429
+
430
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
431
+ if cache:
432
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
433
+ if len(cache_files) == 1:
434
+ filename = cache_files[0]
435
+ return filename if return_filename else open(filename, "rb")
436
+
437
+ # Download.
438
+ url_name = None
439
+ url_data = None
440
+ with requests.Session() as session:
441
+ if verbose:
442
+ print("Downloading %s ..." % url, end="", flush=True)
443
+ for attempts_left in reversed(range(num_attempts)):
444
+ try:
445
+ with session.get(url) as res:
446
+ res.raise_for_status()
447
+ if len(res.content) == 0:
448
+ raise IOError("No data received")
449
+
450
+ if len(res.content) < 8192:
451
+ content_str = res.content.decode("utf-8")
452
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
453
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
454
+ if len(links) == 1:
455
+ url = requests.compat.urljoin(url, links[0])
456
+ raise IOError("Google Drive virus checker nag")
457
+ if "Google Drive - Quota exceeded" in content_str:
458
+ raise IOError("Google Drive download quota exceeded -- please try again later")
459
+
460
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
461
+ url_name = match[1] if match else url
462
+ url_data = res.content
463
+ if verbose:
464
+ print(" done")
465
+ break
466
+ except KeyboardInterrupt:
467
+ raise
468
+ except:
469
+ if not attempts_left:
470
+ if verbose:
471
+ print(" failed")
472
+ raise
473
+ if verbose:
474
+ print(".", end="", flush=True)
475
+
476
+ # Save to cache.
477
+ if cache:
478
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
479
+ safe_name = safe_name[:min(len(safe_name), 128)]
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
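
A brief usage sketch (not part of this commit) for the two helpers re-exported by dnnlib/__init__.py: EasyDict gives attribute-style access to dict entries, and dnnlib.util.open_url() accepts local paths, file:// URLs, and HTTP(S) URLs, caching downloads under make_cache_dir_path('downloads') unless cache=False. The URL below is a placeholder.

    import pickle
    import dnnlib

    opts = dnnlib.EasyDict(batch=64, seed=0)
    print(opts.batch)                       # attribute access instead of opts['batch']

    url = 'https://example.com/model.pkl'   # placeholder; a plain local path also works
    with dnnlib.util.open_url(url, cache=True) as f:
        data = pickle.load(f)               # subsequent calls reuse the cached copy
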
edm/environment.yml ADDED
@@ -0,0 +1,19 @@
1
+ name: edm
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python>=3.8, < 3.10 # package build failures on 3.10
7
+ - pip
8
+ - numpy>=1.20
9
+ - click>=8.0
10
+ - pillow>=8.3.1
11
+ - scipy>=1.7.1
12
+ - pytorch=1.12.1
13
+ - psutil
14
+ - requests
15
+ - tqdm
16
+ - imageio
17
+ - pip:
18
+ - imageio-ffmpeg>=0.4.3
19
+ - pyspng
edm/example.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Minimal standalone example to reproduce the main results from the paper
9
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10
+
11
+ import tqdm
12
+ import pickle
13
+ import numpy as np
14
+ import torch
15
+ import PIL.Image
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def generate_image_grid(
21
+ network_pkl, dest_path,
22
+ seed=0, gridw=8, gridh=8, device=torch.device('cuda'),
23
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
24
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
25
+ ):
26
+ batch_size = gridw * gridh
27
+ torch.manual_seed(seed)
28
+
29
+ # Load network.
30
+ print(f'Loading network from "{network_pkl}"...')
31
+ with dnnlib.util.open_url(network_pkl) as f:
32
+ net = pickle.load(f)['ema'].to(device)
33
+
34
+ # Pick latents and labels.
35
+ print(f'Generating {batch_size} images...')
36
+ latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
37
+ class_labels = None
38
+ if net.label_dim:
39
+ class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
40
+
41
+ # Adjust noise levels based on what's supported by the network.
42
+ sigma_min = max(sigma_min, net.sigma_min)
43
+ sigma_max = min(sigma_max, net.sigma_max)
44
+
45
+ # Time step discretization.
46
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
47
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
48
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
49
+
50
+ # Main sampling loop.
51
+ x_next = latents.to(torch.float64) * t_steps[0]
52
+ for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1
53
+ x_cur = x_next
54
+
55
+ # Increase noise temporarily.
56
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
57
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
58
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)
59
+
60
+ # Euler step.
61
+ denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
62
+ d_cur = (x_hat - denoised) / t_hat
63
+ x_next = x_hat + (t_next - t_hat) * d_cur
64
+
65
+ # Apply 2nd order correction.
66
+ if i < num_steps - 1:
67
+ denoised = net(x_next, t_next, class_labels).to(torch.float64)
68
+ d_prime = (x_next - denoised) / t_next
69
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
70
+
71
+ # Save image grid.
72
+ print(f'Saving image grid to "{dest_path}"...')
73
+ image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
74
+ image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
75
+ image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels)
76
+ image = image.cpu().numpy()
77
+ PIL.Image.fromarray(image, 'RGB').save(dest_path)
78
+ print('Done.')
79
+
80
+ #----------------------------------------------------------------------------
81
+
82
+ def main():
83
+ model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
84
+ generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35
85
+ generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79
86
+ generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79
87
+ generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
+ if __name__ == "__main__":
92
+ main()
93
+
94
+ #----------------------------------------------------------------------------
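
For reference, the time-step discretization and the second-order (Heun) correction implemented in the sampling loop above can be written as (notation following the code, with N = num_steps and D the denoiser net):

    t_i = \left( \sigma_{\max}^{1/\rho} + \tfrac{i}{N-1} \bigl( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \bigr) \right)^{\rho}, \qquad i = 0, \dots, N-1, \qquad t_N = 0

    d_i = \frac{\hat{x}_i - D(\hat{x}_i; \hat{t}_i)}{\hat{t}_i}, \qquad
    x_{i+1} = \hat{x}_i + (t_{i+1} - \hat{t}_i) \bigl( \tfrac{1}{2} d_i + \tfrac{1}{2} d_i' \bigr)

where \hat{x}_i, \hat{t}_i are the values after the optional noise-injection ("churn") step, d_i' is the slope re-evaluated at the Euler prediction for step i+1, and the correction is skipped on the final step where t_{i+1} = 0.
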
edm/fid.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Script for calculating Frechet Inception Distance (FID)."""
9
+
10
+ import os
11
+ import click
12
+ import tqdm
13
+ import pickle
14
+ import numpy as np
15
+ import scipy.linalg
16
+ import torch
17
+ import dnnlib
18
+ from torch_utils import distributed as dist
19
+ from training import dataset
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ def calculate_inception_stats(
24
+ image_path, num_expected=None, seed=0, max_batch_size=64,
25
+ num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
26
+ ):
27
+ # Rank 0 goes first.
28
+ if dist.get_rank() != 0:
29
+ torch.distributed.barrier()
30
+
31
+ # Load Inception-v3 model.
32
+ # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
33
+ dist.print0('Loading Inception-v3 model...')
34
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
35
+ detector_kwargs = dict(return_features=True)
36
+ feature_dim = 2048
37
+ with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
38
+ detector_net = pickle.load(f).to(device)
39
+
40
+ # List images.
41
+ dist.print0(f'Loading images from "{image_path}"...')
42
+ dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
43
+ if num_expected is not None and len(dataset_obj) < num_expected:
44
+ raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
45
+ if len(dataset_obj) < 2:
46
+ raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
47
+
48
+ # Other ranks follow.
49
+ if dist.get_rank() == 0:
50
+ torch.distributed.barrier()
51
+
52
+ # Divide images into batches.
53
+ num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
54
+ all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
55
+ rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
56
+ data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
57
+
58
+ # Accumulate statistics.
59
+ dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
60
+ mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
61
+ sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
62
+ for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
63
+ torch.distributed.barrier()
64
+ if images.shape[0] == 0:
65
+ continue
66
+ if images.shape[1] == 1:
67
+ images = images.repeat([1, 3, 1, 1])
68
+ features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
69
+ mu += features.sum(0)
70
+ sigma += features.T @ features
71
+
72
+ # Calculate grand totals.
73
+ torch.distributed.all_reduce(mu)
74
+ torch.distributed.all_reduce(sigma)
75
+ mu /= len(dataset_obj)
76
+ sigma -= mu.ger(mu) * len(dataset_obj)
77
+ sigma /= len(dataset_obj) - 1
78
+ return mu.cpu().numpy(), sigma.cpu().numpy()
79
+
80
+ #----------------------------------------------------------------------------
81
+
82
+ def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
83
+ m = np.square(mu - mu_ref).sum()
84
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
85
+ fid = m + np.trace(sigma + sigma_ref - s * 2)
86
+ return float(np.real(fid))
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @click.group()
91
+ def main():
92
+ """Calculate Frechet Inception Distance (FID).
93
+
94
+ Examples:
95
+
96
+ \b
97
+ # Generate 50000 images and save them as fid-tmp/*/*.png
98
+ torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
99
+ --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
100
+
101
+ \b
102
+ # Calculate FID
103
+ torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
104
+ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
105
+
106
+ \b
107
+ # Compute dataset reference statistics
108
+ python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
109
+ """
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ @main.command()
114
+ @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
115
+ @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
116
+ @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
117
+ @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
118
+ @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
119
+
120
+ def calc(image_path, ref_path, num_expected, seed, batch):
121
+ """Calculate FID for a given set of images."""
122
+ torch.multiprocessing.set_start_method('spawn')
123
+ dist.init()
124
+
125
+ dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
126
+ ref = None
127
+ if dist.get_rank() == 0:
128
+ with dnnlib.util.open_url(ref_path) as f:
129
+ ref = dict(np.load(f))
130
+
131
+ mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
132
+ dist.print0('Calculating FID...')
133
+ if dist.get_rank() == 0:
134
+ fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
135
+ print(f'{fid:g}')
136
+ torch.distributed.barrier()
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ @main.command()
141
+ @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
142
+ @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
143
+ @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
144
+
145
+ def ref(dataset_path, dest_path, batch):
146
+ """Calculate dataset reference statistics needed by 'calc'."""
147
+ torch.multiprocessing.set_start_method('spawn')
148
+ dist.init()
149
+
150
+ mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
151
+ dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
152
+ if dist.get_rank() == 0:
153
+ if os.path.dirname(dest_path):
154
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
155
+ np.savez(dest_path, mu=mu, sigma=sigma)
156
+
157
+ torch.distributed.barrier()
158
+ dist.print0('Done.')
159
+
160
+ #----------------------------------------------------------------------------
161
+
162
+ if __name__ == "__main__":
163
+ main()
164
+
165
+ #----------------------------------------------------------------------------
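
The statistics accumulated above are the mean and unbiased covariance of the 2048-dimensional Inception-v3 features, and calculate_fid_from_inception_stats() evaluates the Fréchet distance between the two Gaussians they define:

    \mu = \frac{1}{N} \sum_i f_i, \qquad \Sigma = \frac{1}{N-1} \Bigl( \sum_i f_i f_i^\top - N \mu \mu^\top \Bigr)

    \mathrm{FID} = \lVert \mu - \mu_{\mathrm{ref}} \rVert_2^2 + \operatorname{Tr}\bigl( \Sigma + \Sigma_{\mathrm{ref}} - 2 (\Sigma \Sigma_{\mathrm{ref}})^{1/2} \bigr)
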
edm/generate.py ADDED
@@ -0,0 +1,316 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Generate random images using the techniques described in the paper
9
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10
+
11
+ import os
12
+ import re
13
+ import click
14
+ import tqdm
15
+ import pickle
16
+ import numpy as np
17
+ import torch
18
+ import PIL.Image
19
+ import dnnlib
20
+ from torch_utils import distributed as dist
21
+
22
+ #----------------------------------------------------------------------------
23
+ # Proposed EDM sampler (Algorithm 2).
24
+
25
+ def edm_sampler(
26
+ net, latents, class_labels=None, randn_like=torch.randn_like,
27
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
28
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
29
+ ):
30
+ # Adjust noise levels based on what's supported by the network.
31
+ sigma_min = max(sigma_min, net.sigma_min)
32
+ sigma_max = min(sigma_max, net.sigma_max)
33
+
34
+ # Time step discretization.
35
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
36
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
37
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
38
+
39
+ # Main sampling loop.
40
+ x_next = latents.to(torch.float64) * t_steps[0]
41
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
42
+ x_cur = x_next
43
+
44
+ # Increase noise temporarily.
45
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
46
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
47
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
48
+
49
+ # Euler step.
50
+ denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
51
+ d_cur = (x_hat - denoised) / t_hat
52
+ x_next = x_hat + (t_next - t_hat) * d_cur
53
+
54
+ # Apply 2nd order correction.
55
+ if i < num_steps - 1:
56
+ denoised = net(x_next, t_next, class_labels).to(torch.float64)
57
+ d_prime = (x_next - denoised) / t_next
58
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59
+
60
+ return x_next
61
+
62
+ #----------------------------------------------------------------------------
63
+ # Generalized ablation sampler, representing the superset of all sampling
64
+ # methods discussed in the paper.
65
+
66
+ def ablation_sampler(
67
+ net, latents, class_labels=None, randn_like=torch.randn_like,
68
+ num_steps=18, sigma_min=None, sigma_max=None, rho=7,
69
+ solver='heun', discretization='edm', schedule='linear', scaling='none',
70
+ epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
71
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
72
+ ):
73
+ assert solver in ['euler', 'heun']
74
+ assert discretization in ['vp', 've', 'iddpm', 'edm']
75
+ assert schedule in ['vp', 've', 'linear']
76
+ assert scaling in ['vp', 'none']
77
+
78
+ # Helper functions for VP & VE noise level schedules.
79
+ vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
80
+ vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
81
+ vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
82
+ ve_sigma = lambda t: t.sqrt()
83
+ ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
84
+ ve_sigma_inv = lambda sigma: sigma ** 2
85
+
86
+ # Select default noise level range based on the specified time step discretization.
87
+ if sigma_min is None:
88
+ vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
89
+ sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
90
+ if sigma_max is None:
91
+ vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
92
+ sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
93
+
94
+ # Adjust noise levels based on what's supported by the network.
95
+ sigma_min = max(sigma_min, net.sigma_min)
96
+ sigma_max = min(sigma_max, net.sigma_max)
97
+
98
+ # Compute corresponding betas for VP.
99
+ vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
100
+ vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
101
+
102
+ # Define time steps in terms of noise level.
103
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
104
+ if discretization == 'vp':
105
+ orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
106
+ sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
107
+ elif discretization == 've':
108
+ orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
109
+ sigma_steps = ve_sigma(orig_t_steps)
110
+ elif discretization == 'iddpm':
111
+ u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
112
+ alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
113
+ for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
114
+ u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
115
+ u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
116
+ sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
117
+ else:
118
+ assert discretization == 'edm'
119
+ sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
120
+
121
+ # Define noise level schedule.
122
+ if schedule == 'vp':
123
+ sigma = vp_sigma(vp_beta_d, vp_beta_min)
124
+ sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
125
+ sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
126
+ elif schedule == 've':
127
+ sigma = ve_sigma
128
+ sigma_deriv = ve_sigma_deriv
129
+ sigma_inv = ve_sigma_inv
130
+ else:
131
+ assert schedule == 'linear'
132
+ sigma = lambda t: t
133
+ sigma_deriv = lambda t: 1
134
+ sigma_inv = lambda sigma: sigma
135
+
136
+ # Define scaling schedule.
137
+ if scaling == 'vp':
138
+ s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
139
+ s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
140
+ else:
141
+ assert scaling == 'none'
142
+ s = lambda t: 1
143
+ s_deriv = lambda t: 0
144
+
145
+ # Compute final time steps based on the corresponding noise levels.
146
+ t_steps = sigma_inv(net.round_sigma(sigma_steps))
147
+ t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
148
+
149
+ # Main sampling loop.
150
+ t_next = t_steps[0]
151
+ x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
152
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
153
+ x_cur = x_next
154
+
155
+ # Increase noise temporarily.
156
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
157
+ t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
158
+ x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)
159
+
160
+ # Euler step.
161
+ h = t_next - t_hat
162
+ denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
163
+ d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
164
+ x_prime = x_hat + alpha * h * d_cur
165
+ t_prime = t_hat + alpha * h
166
+
167
+ # Apply 2nd order correction.
168
+ if solver == 'euler' or i == num_steps - 1:
169
+ x_next = x_hat + h * d_cur
170
+ else:
171
+ assert solver == 'heun'
172
+ denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
173
+ d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
174
+ x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
175
+
176
+ return x_next
177
+
178
+ #----------------------------------------------------------------------------
179
+ # Wrapper for torch.Generator that allows specifying a different random seed
180
+ # for each sample in a minibatch.
181
+
182
+ class StackedRandomGenerator:
183
+ def __init__(self, device, seeds):
184
+ super().__init__()
185
+ self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
186
+
187
+ def randn(self, size, **kwargs):
188
+ assert size[0] == len(self.generators)
189
+ return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
190
+
191
+ def randn_like(self, input):
192
+ return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
193
+
194
+ def randint(self, *args, size, **kwargs):
195
+ assert size[0] == len(self.generators)
196
+ return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
197
+
198
+ #----------------------------------------------------------------------------
199
+ # Parse a comma separated list of numbers or ranges and return a list of ints.
200
+ # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
201
+
202
+ def parse_int_list(s):
203
+ if isinstance(s, list): return s
204
+ ranges = []
205
+ range_re = re.compile(r'^(\d+)-(\d+)$')
206
+ for p in s.split(','):
207
+ m = range_re.match(p)
208
+ if m:
209
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
210
+ else:
211
+ ranges.append(int(p))
212
+ return ranges
213
+
214
+ #----------------------------------------------------------------------------
215
+
216
+ @click.command()
217
+ @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
218
+ @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
219
+ @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
220
+ @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
221
+ @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
222
+ @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
223
+
224
+ @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
225
+ @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
226
+ @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
227
+ @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
228
+ @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
229
+ @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
230
+ @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
231
+ @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
232
+
233
+ @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
234
+ @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
235
+ @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
236
+ @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
237
+
238
+ def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
239
+ """Generate random images using the techniques described in the paper
240
+ "Elucidating the Design Space of Diffusion-Based Generative Models".
241
+
242
+ Examples:
243
+
244
+ \b
245
+ # Generate 64 images and save them as out/*.png
246
+ python generate.py --outdir=out --seeds=0-63 --batch=64 \\
247
+ --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
248
+
249
+ \b
250
+ # Generate 1024 images using 2 GPUs
251
+ torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
252
+ --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
253
+ """
254
+ dist.init()
255
+ num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
256
+ all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
257
+ rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
258
+
259
+ # Rank 0 goes first.
260
+ if dist.get_rank() != 0:
261
+ torch.distributed.barrier()
262
+
263
+ # Load network.
264
+ dist.print0(f'Loading network from "{network_pkl}"...')
265
+ with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
266
+ net = pickle.load(f)['ema'].to(device)
267
+
268
+ # Other ranks follow.
269
+ if dist.get_rank() == 0:
270
+ torch.distributed.barrier()
271
+
272
+ # Loop over batches.
273
+ dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
274
+ for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
275
+ torch.distributed.barrier()
276
+ batch_size = len(batch_seeds)
277
+ if batch_size == 0:
278
+ continue
279
+
280
+ # Pick latents and labels.
281
+ rnd = StackedRandomGenerator(device, batch_seeds)
282
+ latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
283
+ class_labels = None
284
+ if net.label_dim:
285
+ class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
286
+ if class_idx is not None:
287
+ class_labels[:, :] = 0
288
+ class_labels[:, class_idx] = 1
289
+
290
+ # Generate images.
291
+ sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
292
+ have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
293
+ sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
294
+ images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
295
+
296
+ # Save images.
297
+ images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
298
+ for seed, image_np in zip(batch_seeds, images_np):
299
+ image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
300
+ os.makedirs(image_dir, exist_ok=True)
301
+ image_path = os.path.join(image_dir, f'{seed:06d}.png')
302
+ if image_np.shape[2] == 1:
303
+ PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
304
+ else:
305
+ PIL.Image.fromarray(image_np, 'RGB').save(image_path)
306
+
307
+ # Done.
308
+ torch.distributed.barrier()
309
+ dist.print0('Done.')
310
+
311
+ #----------------------------------------------------------------------------
312
+
313
+ if __name__ == "__main__":
314
+ main()
315
+
316
+ #----------------------------------------------------------------------------
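
A small illustration (not part of this commit) of the seed handling above: parse_int_list() expands the range syntax accepted by --seeds, and StackedRandomGenerator draws an independent noise stream per seed, so the latents for a given seed do not depend on how seeds are batched across processes. The import assumes the edm directory is on sys.path.

    import torch
    from generate import parse_int_list, StackedRandomGenerator  # hypothetical import of the script above

    seeds = parse_int_list('1,2,5-10')   # -> [1, 2, 5, 6, 7, 8, 9, 10]
    rnd = StackedRandomGenerator(torch.device('cpu'), seeds)
    latents = rnd.randn([len(seeds), 3, 32, 32], device=torch.device('cpu'))
    # latents[k] depends only on seeds[k], so splitting the seed list across
    # batches or GPUs reproduces the same per-seed latents.
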
edm/torch_utils/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ # empty
edm/torch_utils/distributed.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ import os
9
+ import torch
10
+ from . import training_stats
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def init():
15
+ if 'MASTER_ADDR' not in os.environ:
16
+ os.environ['MASTER_ADDR'] = 'localhost'
17
+ if 'MASTER_PORT' not in os.environ:
18
+ os.environ['MASTER_PORT'] = '29500'
19
+ if 'RANK' not in os.environ:
20
+ os.environ['RANK'] = '0'
21
+ if 'LOCAL_RANK' not in os.environ:
22
+ os.environ['LOCAL_RANK'] = '0'
23
+ if 'WORLD_SIZE' not in os.environ:
24
+ os.environ['WORLD_SIZE'] = '1'
25
+
26
+ backend = 'gloo' if os.name == 'nt' else 'nccl'
27
+ torch.distributed.init_process_group(backend=backend, init_method='env://')
28
+ torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29
+
30
+ sync_device = torch.device('cuda') if get_world_size() > 1 else None
31
+ training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def get_rank():
36
+ return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37
+
38
+ #----------------------------------------------------------------------------
39
+
40
+ def get_world_size():
41
+ return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ def should_stop():
46
+ return False
47
+
48
+ #----------------------------------------------------------------------------
49
+
50
+ def update_progress(cur, total):
51
+ _ = cur, total
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ def print0(*args, **kwargs):
56
+ if get_rank() == 0:
57
+ print(*args, **kwargs)
58
+
59
+ #----------------------------------------------------------------------------
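
A minimal sketch (not part of this commit) of how the helpers above behave in a single process: before (or without) init(), get_rank() returns 0 and get_world_size() returns 1, so print0() acts like a plain print; init() fills in MASTER_ADDR/PORT, RANK, LOCAL_RANK and WORLD_SIZE itself when the script is launched without torchrun. Assumes the edm directory is on sys.path.

    from torch_utils import distributed as dist  # the module shown above

    print(dist.get_rank())        # 0 while torch.distributed is uninitialized
    print(dist.get_world_size())  # 1 while torch.distributed is uninitialized
    dist.print0('only rank 0 prints this')
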
edm/torch_utils/misc.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ import re
9
+ import contextlib
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+ import edm.dnnlib as dnnlib
14
+
15
+ #----------------------------------------------------------------------------
16
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17
+ # same constant is used multiple times.
18
+
19
+ _constant_cache = dict()
20
+
21
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22
+ value = np.asarray(value)
23
+ if shape is not None:
24
+ shape = tuple(shape)
25
+ if dtype is None:
26
+ dtype = torch.get_default_dtype()
27
+ if device is None:
28
+ device = torch.device('cpu')
29
+ if memory_format is None:
30
+ memory_format = torch.contiguous_format
31
+
32
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33
+ tensor = _constant_cache.get(key, None)
34
+ if tensor is None:
35
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36
+ if shape is not None:
37
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38
+ tensor = tensor.contiguous(memory_format=memory_format)
39
+ _constant_cache[key] = tensor
40
+ return tensor
41
+
42
+ #----------------------------------------------------------------------------
43
+ # Replace NaN/Inf with specified numerical values.
44
+
45
+ try:
46
+ nan_to_num = torch.nan_to_num # 1.8.0a0
47
+ except AttributeError:
48
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49
+ assert isinstance(input, torch.Tensor)
50
+ if posinf is None:
51
+ posinf = torch.finfo(input.dtype).max
52
+ if neginf is None:
53
+ neginf = torch.finfo(input.dtype).min
54
+ assert nan == 0
55
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56
+
57
+ #----------------------------------------------------------------------------
58
+ # Symbolic assert.
59
+
60
+ try:
61
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62
+ except AttributeError:
63
+ symbolic_assert = torch.Assert # 1.7.0
64
+
65
+ #----------------------------------------------------------------------------
66
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
67
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68
+
69
+ @contextlib.contextmanager
70
+ def suppress_tracer_warnings():
71
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72
+ warnings.filters.insert(0, flt)
73
+ yield
74
+ warnings.filters.remove(flt)
75
+
76
+ #----------------------------------------------------------------------------
77
+ # Assert that the shape of a tensor matches the given list of integers.
78
+ # None indicates that the size of a dimension is allowed to vary.
79
+ # Performs symbolic assertion when used in torch.jit.trace().
80
+
81
+ def assert_shape(tensor, ref_shape):
82
+ if tensor.ndim != len(ref_shape):
83
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85
+ if ref_size is None:
86
+ pass
87
+ elif isinstance(ref_size, torch.Tensor):
88
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
89
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90
+ elif isinstance(size, torch.Tensor):
91
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
92
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93
+ elif size != ref_size:
94
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95
+
96
+ #----------------------------------------------------------------------------
97
+ # Function decorator that calls torch.autograd.profiler.record_function().
98
+
99
+ def profiled_function(fn):
100
+ def decorator(*args, **kwargs):
101
+ with torch.autograd.profiler.record_function(fn.__name__):
102
+ return fn(*args, **kwargs)
103
+ decorator.__name__ = fn.__name__
104
+ return decorator
105
+
106
+ #----------------------------------------------------------------------------
107
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
108
+ # indefinitely, shuffling items as it goes.
109
+
110
+ class InfiniteSampler(torch.utils.data.Sampler):
111
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112
+ assert len(dataset) > 0
113
+ assert num_replicas > 0
114
+ assert 0 <= rank < num_replicas
115
+ assert 0 <= window_size <= 1
116
+ super().__init__(dataset)
117
+ self.dataset = dataset
118
+ self.rank = rank
119
+ self.num_replicas = num_replicas
120
+ self.shuffle = shuffle
121
+ self.seed = seed
122
+ self.window_size = window_size
123
+
124
+ def __iter__(self):
125
+ order = np.arange(len(self.dataset))
126
+ rnd = None
127
+ window = 0
128
+ if self.shuffle:
129
+ rnd = np.random.RandomState(self.seed)
130
+ rnd.shuffle(order)
131
+ window = int(np.rint(order.size * self.window_size))
132
+
133
+ idx = 0
134
+ while True:
135
+ i = idx % order.size
136
+ if idx % self.num_replicas == self.rank:
137
+ yield order[i]
138
+ if window >= 2:
139
+ j = (i - rnd.randint(window)) % order.size
140
+ order[i], order[j] = order[j], order[i]
141
+ idx += 1
142
+
143
+ #----------------------------------------------------------------------------
144
+ # Utilities for operating with torch.nn.Module parameters and buffers.
145
+
146
+ def params_and_buffers(module):
147
+ assert isinstance(module, torch.nn.Module)
148
+ return list(module.parameters()) + list(module.buffers())
149
+
150
+ def named_params_and_buffers(module):
151
+ assert isinstance(module, torch.nn.Module)
152
+ return list(module.named_parameters()) + list(module.named_buffers())
153
+
154
+ @torch.no_grad()
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = dict(named_params_and_buffers(src_module))
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ tensor.copy_(src_tensors[name])
163
+
164
+ #----------------------------------------------------------------------------
165
+ # Context manager for easily enabling/disabling DistributedDataParallel
166
+ # synchronization.
167
+
168
+ @contextlib.contextmanager
169
+ def ddp_sync(module, sync):
170
+ assert isinstance(module, torch.nn.Module)
171
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
+ yield
173
+ else:
174
+ with module.no_sync():
175
+ yield
176
+
177
+ #----------------------------------------------------------------------------
178
+ # Check DistributedDataParallel consistency across processes.
179
+
180
+ def check_ddp_consistency(module, ignore_regex=None):
181
+ assert isinstance(module, torch.nn.Module)
182
+ for name, tensor in named_params_and_buffers(module):
183
+ fullname = type(module).__name__ + '.' + name
184
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
+ continue
186
+ tensor = tensor.detach()
187
+ if tensor.is_floating_point():
188
+ tensor = nan_to_num(tensor)
189
+ other = tensor.clone()
190
+ torch.distributed.broadcast(tensor=other, src=0)
191
+ assert (tensor == other).all(), fullname
192
+
193
+ #----------------------------------------------------------------------------
194
+ # Print summary table of module hierarchy.
195
+
196
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197
+ assert isinstance(module, torch.nn.Module)
198
+ assert not isinstance(module, torch.jit.ScriptModule)
199
+ assert isinstance(inputs, (tuple, list))
200
+
201
+ # Register hooks.
202
+ entries = []
203
+ nesting = [0]
204
+ def pre_hook(_mod, _inputs):
205
+ nesting[0] += 1
206
+ def post_hook(mod, _inputs, outputs):
207
+ nesting[0] -= 1
208
+ if nesting[0] <= max_nesting:
209
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214
+
215
+ # Run module.
216
+ outputs = module(*inputs)
217
+ for hook in hooks:
218
+ hook.remove()
219
+
220
+ # Identify unique outputs, parameters, and buffers.
221
+ tensors_seen = set()
222
+ for e in entries:
223
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227
+
228
+ # Filter out redundant entries.
229
+ if skip_redundant:
230
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231
+
232
+ # Construct table.
233
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234
+ rows += [['---'] * len(rows[0])]
235
+ param_total = 0
236
+ buffer_total = 0
237
+ submodule_names = {mod: name for name, mod in module.named_modules()}
238
+ for e in entries:
239
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
240
+ param_size = sum(t.numel() for t in e.unique_params)
241
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
242
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
243
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244
+ rows += [[
245
+ name + (':0' if len(e.outputs) >= 2 else ''),
246
+ str(param_size) if param_size else '-',
247
+ str(buffer_size) if buffer_size else '-',
248
+ (output_shapes + ['-'])[0],
249
+ (output_dtypes + ['-'])[0],
250
+ ]]
251
+ for idx in range(1, len(e.outputs)):
252
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253
+ param_total += param_size
254
+ buffer_total += buffer_size
255
+ rows += [['---'] * len(rows[0])]
256
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257
+
258
+ # Print table.
259
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260
+ print()
261
+ for row in rows:
262
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263
+ print()
264
+ return outputs
265
+
266
+ #----------------------------------------------------------------------------
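A minimal sketch of wiring InfiniteSampler into a DataLoader (illustrative only, not part of this commit; the import path is assumed from this commit's layout):

    import torch
    from edm.torch_utils.misc import InfiniteSampler

    dataset = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32))
    sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=16))
    images, = next(loader)   # the sampler loops forever, so next() never raises StopIteration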
edm/torch_utils/persistence.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Facilities for pickling Python code alongside other data.
9
+
10
+ The pickled code is automatically imported into a separate Python module
11
+ during unpickling. This way, any previously exported pickles will remain
12
+ usable even if the original code is no longer available, or if the current
13
+ version of the code is not consistent with what was originally pickled."""
14
+
15
+ import sys
16
+ import pickle
17
+ import io
18
+ import inspect
19
+ import copy
20
+ import uuid
21
+ import types
22
+ import edm.dnnlib as dnnlib
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ _version = 6 # internal version number
27
+ _decorators = set() # {decorator_class, ...}
28
+ _import_hooks = [] # [hook_function, ...]
29
+ _module_to_src_dict = dict() # {module: src, ...}
30
+ _src_to_module_dict = dict() # {src: module, ...}
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def persistent_class(orig_class):
35
+ r"""Class decorator that extends a given class to save its source code
36
+ when pickled.
37
+
38
+ Example:
39
+
40
+ from torch_utils import persistence
41
+
42
+ @persistence.persistent_class
43
+ class MyNetwork(torch.nn.Module):
44
+ def __init__(self, num_inputs, num_outputs):
45
+ super().__init__()
46
+ self.fc = MyLayer(num_inputs, num_outputs)
47
+ ...
48
+
49
+ @persistence.persistent_class
50
+ class MyLayer(torch.nn.Module):
51
+ ...
52
+
53
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
54
+ source code alongside other internal state (e.g., parameters, buffers,
55
+ and submodules). This way, any previously exported pickle will remain
56
+ usable even if the class definitions have been modified or are no
57
+ longer available.
58
+
59
+ The decorator saves the source code of the entire Python module
60
+ containing the decorated class. It does *not* save the source code of
61
+ any imported modules. Thus, the imported modules must be available
62
+ during unpickling, also including `torch_utils.persistence` itself.
63
+
64
+ It is ok to call functions defined in the same module from the
65
+ decorated class. However, if the decorated class depends on other
66
+ classes defined in the same module, they must be decorated as well.
67
+ This is illustrated in the above example in the case of `MyLayer`.
68
+
69
+ It is also possible to employ the decorator just-in-time before
70
+ calling the constructor. For example:
71
+
72
+ cls = MyLayer
73
+ if want_to_make_it_persistent:
74
+ cls = persistence.persistent_class(cls)
75
+ layer = cls(num_inputs, num_outputs)
76
+
77
+ As an additional feature, the decorator also keeps track of the
78
+ arguments that were used to construct each instance of the decorated
79
+ class. The arguments can be queried via `obj.init_args` and
80
+ `obj.init_kwargs`, and they are automatically pickled alongside other
81
+ object state. This feature can be disabled on a per-instance basis
82
+ by setting `self._record_init_args = False` in the constructor.
83
+
84
+ A typical use case is to first unpickle a previous instance of a
85
+ persistent class, and then upgrade it to use the latest version of
86
+ the source code:
87
+
88
+ with open('old_pickle.pkl', 'rb') as f:
89
+ old_net = pickle.load(f)
90
+ new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92
+ """
93
+ assert isinstance(orig_class, type)
94
+ if is_persistent(orig_class):
95
+ return orig_class
96
+
97
+ assert orig_class.__module__ in sys.modules
98
+ orig_module = sys.modules[orig_class.__module__]
99
+ orig_module_src = _module_to_src(orig_module)
100
+
101
+ class Decorator(orig_class):
102
+ _orig_module_src = orig_module_src
103
+ _orig_class_name = orig_class.__name__
104
+
105
+ def __init__(self, *args, **kwargs):
106
+ super().__init__(*args, **kwargs)
107
+ record_init_args = getattr(self, '_record_init_args', True)
108
+ self._init_args = copy.deepcopy(args) if record_init_args else None
109
+ self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110
+ assert orig_class.__name__ in orig_module.__dict__
111
+ _check_pickleable(self.__reduce__())
112
+
113
+ @property
114
+ def init_args(self):
115
+ assert self._init_args is not None
116
+ return copy.deepcopy(self._init_args)
117
+
118
+ @property
119
+ def init_kwargs(self):
120
+ assert self._init_kwargs is not None
121
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122
+
123
+ def __reduce__(self):
124
+ fields = list(super().__reduce__())
125
+ fields += [None] * max(3 - len(fields), 0)
126
+ if fields[0] is not _reconstruct_persistent_obj:
127
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
129
+ fields[1] = (meta,) # reconstruct args
130
+ fields[2] = None # state dict
131
+ return tuple(fields)
132
+
133
+ Decorator.__name__ = orig_class.__name__
134
+ Decorator.__module__ = orig_class.__module__
135
+ _decorators.add(Decorator)
136
+ return Decorator
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ def is_persistent(obj):
141
+ r"""Test whether the given object or class is persistent, i.e.,
142
+ whether it will save its source code when pickled.
143
+ """
144
+ try:
145
+ if obj in _decorators:
146
+ return True
147
+ except TypeError:
148
+ pass
149
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150
+
151
+ #----------------------------------------------------------------------------
152
+
153
+ def import_hook(hook):
154
+ r"""Register an import hook that is called whenever a persistent object
155
+ is being unpickled. A typical use case is to patch the pickled source
156
+ code to avoid errors and inconsistencies when the API of some imported
157
+ module has changed.
158
+
159
+ The hook should have the following signature:
160
+
161
+ hook(meta) -> modified meta
162
+
163
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164
+
165
+ type: Type of the persistent object, e.g. `'class'`.
166
+ version: Internal version number of `torch_utils.persistence`.
167
+ module_src: Original source code of the Python module.
168
+ class_name: Class name in the original Python module.
169
+ state: Internal state of the object.
170
+
171
+ Example:
172
+
173
+ @persistence.import_hook
174
+ def wreck_my_network(meta):
175
+ if meta.class_name == 'MyNetwork':
176
+ print('MyNetwork is being imported. I will wreck it!')
177
+ meta.module_src = meta.module_src.replace("True", "False")
178
+ return meta
179
+ """
180
+ assert callable(hook)
181
+ _import_hooks.append(hook)
182
+
183
+ #----------------------------------------------------------------------------
184
+
185
+ def _reconstruct_persistent_obj(meta):
186
+ r"""Hook that is called internally by the `pickle` module to unpickle
187
+ a persistent object.
188
+ """
189
+ meta = dnnlib.EasyDict(meta)
190
+ meta.state = dnnlib.EasyDict(meta.state)
191
+ for hook in _import_hooks:
192
+ meta = hook(meta)
193
+ assert meta is not None
194
+
195
+ assert meta.version == _version
196
+ module = _src_to_module(meta.module_src)
197
+
198
+ assert meta.type == 'class'
199
+ orig_class = module.__dict__[meta.class_name]
200
+ decorator_class = persistent_class(orig_class)
201
+ obj = decorator_class.__new__(decorator_class)
202
+
203
+ setstate = getattr(obj, '__setstate__', None)
204
+ if callable(setstate):
205
+ setstate(meta.state) # pylint: disable=not-callable
206
+ else:
207
+ obj.__dict__.update(meta.state)
208
+ return obj
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ def _module_to_src(module):
213
+ r"""Query the source code of a given Python module.
214
+ """
215
+ src = _module_to_src_dict.get(module, None)
216
+ if src is None:
217
+ src = inspect.getsource(module)
218
+ _module_to_src_dict[module] = src
219
+ _src_to_module_dict[src] = module
220
+ return src
221
+
222
+ def _src_to_module(src):
223
+ r"""Get or create a Python module for the given source code.
224
+ """
225
+ module = _src_to_module_dict.get(src, None)
226
+ if module is None:
227
+ module_name = "_imported_module_" + uuid.uuid4().hex
228
+ module = types.ModuleType(module_name)
229
+ sys.modules[module_name] = module
230
+ _module_to_src_dict[module] = src
231
+ _src_to_module_dict[src] = module
232
+ exec(src, module.__dict__) # pylint: disable=exec-used
233
+ return module
234
+
235
+ #----------------------------------------------------------------------------
236
+
237
+ def _check_pickleable(obj):
238
+ r"""Check that the given object is pickleable, raising an exception if
239
+ it is not. This function is expected to be considerably more efficient
240
+ than actually pickling the object.
241
+ """
242
+ def recurse(obj):
243
+ if isinstance(obj, (list, tuple, set)):
244
+ return [recurse(x) for x in obj]
245
+ if isinstance(obj, dict):
246
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
247
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248
+ return None # Python primitive types are pickleable.
249
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250
+ return None # NumPy arrays and PyTorch tensors are pickleable.
251
+ if is_persistent(obj):
252
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
253
+ return obj
254
+ with io.BytesIO() as f:
255
+ pickle.dump(recurse(obj), f)
256
+
257
+ #----------------------------------------------------------------------------
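A minimal round-trip sketch of persistent_class (illustrative only, not part of this commit); the class and file names are hypothetical, and the defining module must be a real file so that inspect.getsource() can read it:

    import pickle
    import torch
    from edm.torch_utils import persistence   # import path assumed from this commit's layout

    @persistence.persistent_class
    class Tiny(torch.nn.Module):
        def __init__(self, width):
            super().__init__()
            self.fc = torch.nn.Linear(width, width)

    net = Tiny(width=8)
    blob = pickle.dumps(net)        # embeds the defining module's source code in the pickle
    restored = pickle.loads(blob)   # rebuilt through _reconstruct_persistent_obj()
    print(restored.init_kwargs)     # EasyDict with width=8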
edm/torch_utils/training_stats.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Facilities for reporting and collecting training statistics across
9
+ multiple processes and devices. The interface is designed to minimize
10
+ synchronization overhead as well as the amount of boilerplate in user
11
+ code."""
12
+
13
+ import re
14
+ import numpy as np
15
+ import torch
16
+ import edm.dnnlib as dnnlib
17
+
18
+ from . import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
25
+ _rank = 0 # Rank of the current process.
26
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
27
+ _sync_called = False # Has _sync() been called yet?
28
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def init_multiprocessing(rank, sync_device):
34
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
35
+ across multiple processes.
36
+
37
+ This function must be called after
38
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
39
+ The call is not necessary if multi-process collection is not needed.
40
+
41
+ Args:
42
+ rank: Rank of the current process.
43
+ sync_device: PyTorch device to use for inter-process
44
+ communication, or None to disable multi-process
45
+ collection. Typically `torch.device('cuda', rank)`.
46
+ """
47
+ global _rank, _sync_device
48
+ assert not _sync_called
49
+ _rank = rank
50
+ _sync_device = sync_device
51
+
52
+ #----------------------------------------------------------------------------
53
+
54
+ @misc.profiled_function
55
+ def report(name, value):
56
+ r"""Broadcasts the given set of scalars to all interested instances of
57
+ `Collector`, across device and process boundaries.
58
+
59
+ This function is expected to be extremely cheap and can be safely
60
+ called from anywhere in the training loop, loss function, or inside a
61
+ `torch.nn.Module`.
62
+
63
+ Warning: The current implementation expects the set of unique names to
64
+ be consistent across processes. Please make sure that `report()` is
65
+ called at least once for each unique name by each process, and in the
66
+ same order. If a given process has no scalars to broadcast, it can do
67
+ `report(name, [])` (empty list).
68
+
69
+ Args:
70
+ name: Arbitrary string specifying the name of the statistic.
71
+ Averages are accumulated separately for each unique name.
72
+ value: Arbitrary set of scalars. Can be a list, tuple,
73
+ NumPy array, PyTorch tensor, or Python scalar.
74
+
75
+ Returns:
76
+ The same `value` that was passed in.
77
+ """
78
+ if name not in _counters:
79
+ _counters[name] = dict()
80
+
81
+ elems = torch.as_tensor(value)
82
+ if elems.numel() == 0:
83
+ return value
84
+
85
+ elems = elems.detach().flatten().to(_reduce_dtype)
86
+ moments = torch.stack([
87
+ torch.ones_like(elems).sum(),
88
+ elems.sum(),
89
+ elems.square().sum(),
90
+ ])
91
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
92
+ moments = moments.to(_counter_dtype)
93
+
94
+ device = moments.device
95
+ if device not in _counters[name]:
96
+ _counters[name][device] = torch.zeros_like(moments)
97
+ _counters[name][device].add_(moments)
98
+ return value
99
+
100
+ #----------------------------------------------------------------------------
101
+
102
+ def report0(name, value):
103
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
104
+ but ignores any scalars provided by the other processes.
105
+ See `report()` for further details.
106
+ """
107
+ report(name, value if _rank == 0 else [])
108
+ return value
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
+ class Collector:
113
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
114
+ computes their long-term averages (mean and standard deviation) over
115
+ user-defined periods of time.
116
+
117
+ The averages are first collected into internal counters that are not
118
+ directly visible to the user. They are then copied to the user-visible
119
+ state as a result of calling `update()` and can then be queried using
120
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121
+ internal counters for the next round, so that the user-visible state
122
+ effectively reflects averages collected between the last two calls to
123
+ `update()`.
124
+
125
+ Args:
126
+ regex: Regular expression defining which statistics to
127
+ collect. The default is to collect everything.
128
+ keep_previous: Whether to retain the previous averages if no
129
+ scalars were collected on a given round
130
+ (default: True).
131
+ """
132
+ def __init__(self, regex='.*', keep_previous=True):
133
+ self._regex = re.compile(regex)
134
+ self._keep_previous = keep_previous
135
+ self._cumulative = dict()
136
+ self._moments = dict()
137
+ self.update()
138
+ self._moments.clear()
139
+
140
+ def names(self):
141
+ r"""Returns the names of all statistics broadcasted so far that
142
+ match the regular expression specified at construction time.
143
+ """
144
+ return [name for name in _counters if self._regex.fullmatch(name)]
145
+
146
+ def update(self):
147
+ r"""Copies current values of the internal counters to the
148
+ user-visible state and resets them for the next round.
149
+
150
+ If `keep_previous=True` was specified at construction time, the
151
+ operation is skipped for statistics that have received no scalars
152
+ since the last update, retaining their previous averages.
153
+
154
+ This method performs a number of GPU-to-CPU transfers and one
155
+ `torch.distributed.all_reduce()`. It is intended to be called
156
+ periodically in the main training loop, typically once every
157
+ N training steps.
158
+ """
159
+ if not self._keep_previous:
160
+ self._moments.clear()
161
+ for name, cumulative in _sync(self.names()):
162
+ if name not in self._cumulative:
163
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164
+ delta = cumulative - self._cumulative[name]
165
+ self._cumulative[name].copy_(cumulative)
166
+ if float(delta[0]) != 0:
167
+ self._moments[name] = delta
168
+
169
+ def _get_delta(self, name):
170
+ r"""Returns the raw moments that were accumulated for the given
171
+ statistic between the last two calls to `update()`, or zero if
172
+ no scalars were collected.
173
+ """
174
+ assert self._regex.fullmatch(name)
175
+ if name not in self._moments:
176
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177
+ return self._moments[name]
178
+
179
+ def num(self, name):
180
+ r"""Returns the number of scalars that were accumulated for the given
181
+ statistic between the last two calls to `update()`, or zero if
182
+ no scalars were collected.
183
+ """
184
+ delta = self._get_delta(name)
185
+ return int(delta[0])
186
+
187
+ def mean(self, name):
188
+ r"""Returns the mean of the scalars that were accumulated for the
189
+ given statistic between the last two calls to `update()`, or NaN if
190
+ no scalars were collected.
191
+ """
192
+ delta = self._get_delta(name)
193
+ if int(delta[0]) == 0:
194
+ return float('nan')
195
+ return float(delta[1] / delta[0])
196
+
197
+ def std(self, name):
198
+ r"""Returns the standard deviation of the scalars that were
199
+ accumulated for the given statistic between the last two calls to
200
+ `update()`, or NaN if no scalars were collected.
201
+ """
202
+ delta = self._get_delta(name)
203
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204
+ return float('nan')
205
+ if int(delta[0]) == 1:
206
+ return float(0)
207
+ mean = float(delta[1] / delta[0])
208
+ raw_var = float(delta[2] / delta[0])
209
+ return np.sqrt(max(raw_var - np.square(mean), 0))
210
+
211
+ def as_dict(self):
212
+ r"""Returns the averages accumulated between the last two calls to
213
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
214
+
215
+ dnnlib.EasyDict(
216
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217
+ ...
218
+ )
219
+ """
220
+ stats = dnnlib.EasyDict()
221
+ for name in self.names():
222
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223
+ return stats
224
+
225
+ def __getitem__(self, name):
226
+ r"""Convenience getter.
227
+ `collector[name]` is a synonym for `collector.mean(name)`.
228
+ """
229
+ return self.mean(name)
230
+
231
+ #----------------------------------------------------------------------------
232
+
233
+ def _sync(names):
234
+ r"""Synchronize the global cumulative counters across devices and
235
+ processes. Called internally by `Collector.update()`.
236
+ """
237
+ if len(names) == 0:
238
+ return []
239
+ global _sync_called
240
+ _sync_called = True
241
+
242
+ # Collect deltas within current rank.
243
+ deltas = []
244
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
245
+ for name in names:
246
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247
+ for counter in _counters[name].values():
248
+ delta.add_(counter.to(device))
249
+ counter.copy_(torch.zeros_like(counter))
250
+ deltas.append(delta)
251
+ deltas = torch.stack(deltas)
252
+
253
+ # Sum deltas across ranks.
254
+ if _sync_device is not None:
255
+ torch.distributed.all_reduce(deltas)
256
+
257
+ # Update cumulative values.
258
+ deltas = deltas.cpu()
259
+ for idx, name in enumerate(names):
260
+ if name not in _cumulative:
261
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
262
+ _cumulative[name].add_(deltas[idx])
263
+
264
+ # Return name-value pairs.
265
+ return [(name, _cumulative[name]) for name in names]
266
+
267
+ #----------------------------------------------------------------------------
268
+ # Convenience.
269
+
270
+ default_collector = Collector()
271
+
272
+ #----------------------------------------------------------------------------
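A minimal single-process sketch of report() and Collector (illustrative only, not part of this commit; the import path is assumed from this commit's layout):

    from edm.torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(100):
        training_stats.report('Loss/total', 1.0 / (step + 1))   # stand-in for a real loss value
    collector.update()   # copies the internal counters to the user-visible state
    print(collector.mean('Loss/total'), collector.std('Loss/total'))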
edm/train.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Train diffusion-based generative model using the techniques described in the
9
+ paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10
+
11
+ import os
12
+ import re
13
+ import json
14
+ import click
15
+ import torch
16
+ import dnnlib
17
+ from torch_utils import distributed as dist
18
+ from training import training_loop
19
+
20
+ import warnings
21
+ warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
22
+
23
+ #----------------------------------------------------------------------------
24
+ # Parse a comma separated list of numbers or ranges and return a list of ints.
25
+ # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
26
+
27
+ def parse_int_list(s):
28
+ if isinstance(s, list): return s
29
+ ranges = []
30
+ range_re = re.compile(r'^(\d+)-(\d+)$')
31
+ for p in s.split(','):
32
+ m = range_re.match(p)
33
+ if m:
34
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
35
+ else:
36
+ ranges.append(int(p))
37
+ return ranges
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ @click.command()
42
+
43
+ # Main options.
44
+ @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
45
+ @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
46
+ @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
47
+ @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
48
+ @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
49
+
50
+ # Hyperparameters.
51
+ @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
52
+ @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
53
+ @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
54
+ @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
55
+ @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
56
+ @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
57
+ @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
58
+ @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
59
+ @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
60
+ @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
61
+
62
+ # Performance-related.
63
+ @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
64
+ @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
65
+ @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
66
+ @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
67
+ @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
68
+
69
+ # I/O-related.
70
+ @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
71
+ @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
72
+ @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
73
+ @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
74
+ @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
75
+ @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
76
+ @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
77
+ @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
78
+ @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
79
+
80
+ def main(**kwargs):
81
+ """Train diffusion-based generative model using the techniques described in the
82
+ paper "Elucidating the Design Space of Diffusion-Based Generative Models".
83
+
84
+ Examples:
85
+
86
+ \b
87
+ # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
88
+ torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\
89
+ --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
90
+ """
91
+ opts = dnnlib.EasyDict(kwargs)
92
+ torch.multiprocessing.set_start_method('spawn')
93
+ dist.init()
94
+
95
+ # Initialize config dict.
96
+ c = dnnlib.EasyDict()
97
+ c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
98
+ c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
99
+ c.network_kwargs = dnnlib.EasyDict()
100
+ c.loss_kwargs = dnnlib.EasyDict()
101
+ c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
102
+
103
+ # Validate dataset options.
104
+ try:
105
+ dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
106
+ dataset_name = dataset_obj.name
107
+ c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
108
+ c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
109
+ if opts.cond and not dataset_obj.has_labels:
110
+ raise click.ClickException('--cond=True requires labels specified in dataset.json')
111
+ del dataset_obj # conserve memory
112
+ except IOError as err:
113
+ raise click.ClickException(f'--data: {err}')
114
+
115
+ # Network architecture.
116
+ if opts.arch == 'ddpmpp':
117
+ c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
118
+ c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
119
+ elif opts.arch == 'ncsnpp':
120
+ c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
121
+ c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
122
+ else:
123
+ assert opts.arch == 'adm'
124
+ c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
125
+
126
+ # Preconditioning & loss function.
127
+ if opts.precond == 'vp':
128
+ c.network_kwargs.class_name = 'training.networks.VPPrecond'
129
+ c.loss_kwargs.class_name = 'training.loss.VPLoss'
130
+ elif opts.precond == 've':
131
+ c.network_kwargs.class_name = 'training.networks.VEPrecond'
132
+ c.loss_kwargs.class_name = 'training.loss.VELoss'
133
+ else:
134
+ assert opts.precond == 'edm'
135
+ c.network_kwargs.class_name = 'training.networks.EDMPrecond'
136
+ c.loss_kwargs.class_name = 'training.loss.EDMLoss'
137
+
138
+ # Network options.
139
+ if opts.cbase is not None:
140
+ c.network_kwargs.model_channels = opts.cbase
141
+ if opts.cres is not None:
142
+ c.network_kwargs.channel_mult = opts.cres
143
+ if opts.augment:
144
+ c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
145
+ c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
146
+ c.network_kwargs.augment_dim = 9
147
+ c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
148
+
149
+ # Training options.
150
+ c.total_kimg = max(int(opts.duration * 1000), 1)
151
+ c.ema_halflife_kimg = int(opts.ema * 1000)
152
+ c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
153
+ c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
154
+ c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
155
+
156
+ # Random seed.
157
+ if opts.seed is not None:
158
+ c.seed = opts.seed
159
+ else:
160
+ seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
161
+ torch.distributed.broadcast(seed, src=0)
162
+ c.seed = int(seed)
163
+
164
+ # Transfer learning and resume.
165
+ if opts.transfer is not None:
166
+ if opts.resume is not None:
167
+ raise click.ClickException('--transfer and --resume cannot be specified at the same time')
168
+ c.resume_pkl = opts.transfer
169
+ c.ema_rampup_ratio = None
170
+ elif opts.resume is not None:
171
+ match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
172
+ if not match or not os.path.isfile(opts.resume):
173
+ raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
174
+ c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
175
+ c.resume_kimg = int(match.group(1))
176
+ c.resume_state_dump = opts.resume
177
+
178
+ # Description string.
179
+ cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
180
+ dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
181
+ desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
182
+ if opts.desc is not None:
183
+ desc += f'-{opts.desc}'
184
+
185
+ # Pick output directory.
186
+ if dist.get_rank() != 0:
187
+ c.run_dir = None
188
+ elif opts.nosubdir:
189
+ c.run_dir = opts.outdir
190
+ else:
191
+ prev_run_dirs = []
192
+ if os.path.isdir(opts.outdir):
193
+ prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
194
+ prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
195
+ prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
196
+ cur_run_id = max(prev_run_ids, default=-1) + 1
197
+ c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
198
+ assert not os.path.exists(c.run_dir)
199
+
200
+ # Print options.
201
+ dist.print0()
202
+ dist.print0('Training options:')
203
+ dist.print0(json.dumps(c, indent=2))
204
+ dist.print0()
205
+ dist.print0(f'Output directory: {c.run_dir}')
206
+ dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
207
+ dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
208
+ dist.print0(f'Network architecture: {opts.arch}')
209
+ dist.print0(f'Preconditioning & loss: {opts.precond}')
210
+ dist.print0(f'Number of GPUs: {dist.get_world_size()}')
211
+ dist.print0(f'Batch size: {c.batch_size}')
212
+ dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
213
+ dist.print0()
214
+
215
+ # Dry run?
216
+ if opts.dry_run:
217
+ dist.print0('Dry run; exiting.')
218
+ return
219
+
220
+ # Create output directory.
221
+ dist.print0('Creating output directory...')
222
+ if dist.get_rank() == 0:
223
+ os.makedirs(c.run_dir, exist_ok=True)
224
+ with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
225
+ json.dump(c, f, indent=2)
226
+ dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
227
+
228
+ # Train.
229
+ training_loop.training_loop(**c)
230
+
231
+ #----------------------------------------------------------------------------
232
+
233
+ if __name__ == "__main__":
234
+ main()
235
+
236
+ #----------------------------------------------------------------------------
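The docstring above shows an 8-GPU launch; a single-GPU dry run would presumably look like the following (illustrative only, not part of this commit; the dataset path is the one used in the docstring example):

    torchrun --standalone --nproc_per_node=1 train.py --outdir=training-runs \
        --data=datasets/cifar10-32x32.zip --arch=ddpmpp --dry-run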
edm/training/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ # empty
edm/training/augment.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Augmentation pipeline used in the paper
9
+ "Elucidating the Design Space of Diffusion-Based Generative Models".
10
+ Built around the same concepts that were originally proposed in the paper
11
+ "Training Generative Adversarial Networks with Limited Data"."""
12
+
13
+ import numpy as np
14
+ import torch
15
+ from torch_utils import persistence
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Coefficients of various wavelet decomposition low-pass filters.
20
+
21
+ wavelets = {
22
+ 'haar': [0.7071067811865476, 0.7071067811865476],
23
+ 'db1': [0.7071067811865476, 0.7071067811865476],
24
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
25
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
26
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
27
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
28
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
29
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
30
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
31
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
32
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
33
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
34
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
35
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
36
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
37
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
38
+ }
39
+
40
+ #----------------------------------------------------------------------------
41
+ # Helpers for constructing transformation matrices.
42
+
43
+ def matrix(*rows, device=None):
44
+ assert all(len(row) == len(rows[0]) for row in rows)
45
+ elems = [x for row in rows for x in row]
46
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
47
+ if len(ref) == 0:
48
+ return misc.constant(np.asarray(rows), device=device)
49
+ assert device is None or device == ref[0].device
50
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
51
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
52
+
53
+ def translate2d(tx, ty, **kwargs):
54
+ return matrix(
55
+ [1, 0, tx],
56
+ [0, 1, ty],
57
+ [0, 0, 1],
58
+ **kwargs)
59
+
60
+ def translate3d(tx, ty, tz, **kwargs):
61
+ return matrix(
62
+ [1, 0, 0, tx],
63
+ [0, 1, 0, ty],
64
+ [0, 0, 1, tz],
65
+ [0, 0, 0, 1],
66
+ **kwargs)
67
+
68
+ def scale2d(sx, sy, **kwargs):
69
+ return matrix(
70
+ [sx, 0, 0],
71
+ [0, sy, 0],
72
+ [0, 0, 1],
73
+ **kwargs)
74
+
75
+ def scale3d(sx, sy, sz, **kwargs):
76
+ return matrix(
77
+ [sx, 0, 0, 0],
78
+ [0, sy, 0, 0],
79
+ [0, 0, sz, 0],
80
+ [0, 0, 0, 1],
81
+ **kwargs)
82
+
83
+ def rotate2d(theta, **kwargs):
84
+ return matrix(
85
+ [torch.cos(theta), torch.sin(-theta), 0],
86
+ [torch.sin(theta), torch.cos(theta), 0],
87
+ [0, 0, 1],
88
+ **kwargs)
89
+
90
+ def rotate3d(v, theta, **kwargs):
91
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
92
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
93
+ return matrix(
94
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
95
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
96
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
97
+ [0, 0, 0, 1],
98
+ **kwargs)
99
+
100
+ def translate2d_inv(tx, ty, **kwargs):
101
+ return translate2d(-tx, -ty, **kwargs)
102
+
103
+ def scale2d_inv(sx, sy, **kwargs):
104
+ return scale2d(1 / sx, 1 / sy, **kwargs)
105
+
106
+ def rotate2d_inv(theta, **kwargs):
107
+ return rotate2d(-theta, **kwargs)
108
+
109
+ #----------------------------------------------------------------------------
110
+ # Augmentation pipeline main class.
111
+ # All augmentations are disabled by default; individual augmentations can
112
+ # be enabled by setting their probability multipliers to 1.
113
+
114
+ @persistence.persistent_class
115
+ class AugmentPipe:
116
+ def __init__(self, p=1,
117
+ xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125,
118
+ scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125,
119
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
120
+ ):
121
+ super().__init__()
122
+ self.p = float(p) # Overall multiplier for augmentation probability.
123
+
124
+ # Pixel blitting.
125
+ self.xflip = float(xflip) # Probability multiplier for x-flip.
126
+ self.yflip = float(yflip) # Probability multiplier for y-flip.
127
+ self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation.
128
+ self.translate_int = float(translate_int) # Probability multiplier for integer translation.
129
+ self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions.
130
+
131
+ # Geometric transformations.
132
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
133
+ self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation.
134
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
135
+ self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
136
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
137
+ self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
138
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
139
+ self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
140
+ self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.
141
+
142
+ # Color transformations.
143
+ self.brightness = float(brightness) # Probability multiplier for brightness.
144
+ self.contrast = float(contrast) # Probability multiplier for contrast.
145
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
146
+ self.hue = float(hue) # Probability multiplier for hue rotation.
147
+ self.saturation = float(saturation) # Probability multiplier for saturation.
148
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
149
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
150
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
151
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
152
+
153
+ def __call__(self, images):
154
+ N, C, H, W = images.shape
155
+ device = images.device
156
+ labels = [torch.zeros([images.shape[0], 0], device=device)]
157
+
158
+ # ---------------
159
+ # Pixel blitting.
160
+ # ---------------
161
+
162
+ if self.xflip > 0:
163
+ w = torch.randint(2, [N, 1, 1, 1], device=device)
164
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
165
+ images = torch.where(w == 1, images.flip(3), images)
166
+ labels += [w]
167
+
168
+ if self.yflip > 0:
169
+ w = torch.randint(2, [N, 1, 1, 1], device=device)
170
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
171
+ images = torch.where(w == 1, images.flip(2), images)
172
+ labels += [w]
173
+
174
+ if self.rotate_int > 0:
175
+ w = torch.randint(4, [N, 1, 1, 1], device=device)
176
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
177
+ images = torch.where((w == 1) | (w == 2), images.flip(3), images)
178
+ images = torch.where((w == 2) | (w == 3), images.flip(2), images)
179
+ images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
180
+ labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]
181
+
182
+ if self.translate_int > 0:
183
+ w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
184
+ w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
185
+ tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
186
+ ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
187
+ b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
188
+ x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
189
+ y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
190
+ images = images.flatten()[(((b * C) + c) * H + y) * W + x]
191
+ labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]
192
+
193
+ # ------------------------------------------------
194
+ # Select parameters for geometric transformations.
195
+ # ------------------------------------------------
196
+
197
+ I_3 = torch.eye(3, device=device)
198
+ G_inv = I_3
199
+
200
+ if self.scale > 0:
201
+ w = torch.randn([N], device=device)
202
+ w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w))
203
+ s = w.mul(self.scale_std).exp2()
204
+ G_inv = G_inv @ scale2d_inv(s, s)
205
+ labels += [w]
206
+
207
+ if self.rotate_frac > 0:
208
+ w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max)
209
+ w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w))
210
+ G_inv = G_inv @ rotate2d_inv(-w)
211
+ labels += [w.cos() - 1, w.sin()]
212
+
213
+ if self.aniso > 0:
214
+ w = torch.randn([N], device=device)
215
+ r = (torch.rand([N], device=device) * 2 - 1) * np.pi
216
+ w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w))
217
+ r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r))
218
+ s = w.mul(self.aniso_std).exp2()
219
+ G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r)
220
+ labels += [w * r.cos(), w * r.sin()]
221
+
222
+ if self.translate_frac > 0:
223
+ w = torch.randn([2, N], device=device)
224
+ w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w))
225
+ G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std))
226
+ labels += [w[0], w[1]]
227
+
228
+ # ----------------------------------
229
+ # Execute geometric transformations.
230
+ # ----------------------------------
231
+
232
+ if G_inv is not I_3:
233
+ cx = (W - 1) / 2
234
+ cy = (H - 1) / 2
235
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
236
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
237
+ Hz = np.asarray(wavelets['sym6'], dtype=np.float32)
238
+ Hz_pad = len(Hz) // 4
239
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
240
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
241
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
242
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
243
+ margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device))
244
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
245
+
246
+ # Pad image and adjust origin.
247
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
248
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
249
+
250
+ # Upsample.
251
+ conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
252
+ conv_pad = (len(Hz) + 1) // 2
253
+ images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1]
254
+ images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad])
255
+ images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :]
256
+ images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0])
257
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
258
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
259
+
260
+ # Execute transformation.
261
+ shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2]
262
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
263
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
264
+ images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
265
+
266
+ # Downsample and crop.
267
+ conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
268
+ conv_pad = (len(Hz) - 1) // 2
269
+ images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad]
270
+ images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :]
271
+
272
+ # --------------------------------------------
273
+ # Select parameters for color transformations.
274
+ # --------------------------------------------
275
+
276
+ I_4 = torch.eye(4, device=device)
277
+ M = I_4
278
+ luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)
279
+
280
+ if self.brightness > 0:
281
+ w = torch.randn([N], device=device)
282
+ w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w))
283
+ b = w * self.brightness_std
284
+ M = translate3d(b, b, b) @ M
285
+ labels += [w]
286
+
287
+ if self.contrast > 0:
288
+ w = torch.randn([N], device=device)
289
+ w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w))
290
+ c = w.mul(self.contrast_std).exp2()
291
+ M = scale3d(c, c, c) @ M
292
+ labels += [w]
293
+
294
+ if self.lumaflip > 0:
295
+ w = torch.randint(2, [N, 1, 1], device=device)
296
+ w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w))
297
+ M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M
298
+ labels += [w]
299
+
300
+ if self.hue > 0:
301
+ w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max)
302
+ w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w))
303
+ M = rotate3d(luma_axis, w) @ M
304
+ labels += [w.cos() - 1, w.sin()]
305
+
306
+ if self.saturation > 0:
307
+ w = torch.randn([N, 1, 1], device=device)
308
+ w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w))
309
+ M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M
310
+ labels += [w]
311
+
312
+ # ------------------------------
313
+ # Execute color transformations.
314
+ # ------------------------------
315
+
316
+ if M is not I_4:
317
+ images = images.reshape([N, C, H * W])
318
+ if C == 3:
319
+ images = M[:, :3, :3] @ images + M[:, :3, 3:]
320
+ elif C == 1:
321
+ M = M[:, :3, :].mean(dim=1, keepdims=True)
322
+ images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:]
323
+ else:
324
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
325
+ images = images.reshape([N, C, H, W])
326
+
327
+ labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1)
328
+ return images, labels
329
+
330
+ #----------------------------------------------------------------------------
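A minimal usage sketch of the AugmentPipe class added above (illustrative only, not part of this commit): the probability multipliers, the p=0.12 value, and the edm.training.augment import path are assumptions chosen for the example; the pipeline returns the augmented batch together with a label tensor describing which augmentations fired.

import torch
from edm.training.augment import AugmentPipe  # assumed module path for this repo layout

# Enable flips and the geometric augmentations with an overall probability multiplier of 0.12.
pipe = AugmentPipe(p=0.12, xflip=1, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
images = torch.randn(4, 3, 32, 32)       # NCHW float batch
augmented, aug_labels = pipe(images)      # aug_labels encodes the applied transformations

The aug_labels tensor is what the loss functions in loss.py below pass to the denoiser as augment_labels, so the network can condition on the augmentation and it does not leak into generated samples.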
edm/training/dataset.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Streaming images and labels from datasets created with dataset_tool.py."""
9
+
10
+ import os
11
+ import numpy as np
12
+ import zipfile
13
+ import PIL.Image
14
+ import json
15
+ import torch
16
+ import dnnlib
17
+
18
+ try:
19
+ import pyspng
20
+ except ImportError:
21
+ pyspng = None
22
+
23
+ #----------------------------------------------------------------------------
24
+ # Abstract base class for datasets.
25
+
26
+ class Dataset(torch.utils.data.Dataset):
27
+ def __init__(self,
28
+ name, # Name of the dataset.
29
+ raw_shape, # Shape of the raw image data (NCHW).
30
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
31
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
32
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
33
+ random_seed = 0, # Random seed to use when applying max_size.
34
+ cache = False, # Cache images in CPU memory?
35
+ ):
36
+ self._name = name
37
+ self._raw_shape = list(raw_shape)
38
+ self._use_labels = use_labels
39
+ self._cache = cache
40
+ self._cached_images = dict() # {raw_idx: np.ndarray, ...}
41
+ self._raw_labels = None
42
+ self._label_shape = None
43
+
44
+ # Apply max_size.
45
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
46
+ if (max_size is not None) and (self._raw_idx.size > max_size):
47
+ np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
48
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
49
+
50
+ # Apply xflip.
51
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
52
+ if xflip:
53
+ self._raw_idx = np.tile(self._raw_idx, 2)
54
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
55
+
56
+ def _get_raw_labels(self):
57
+ if self._raw_labels is None:
58
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
59
+ if self._raw_labels is None:
60
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
61
+ assert isinstance(self._raw_labels, np.ndarray)
62
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
63
+ assert self._raw_labels.dtype in [np.float32, np.int64]
64
+ if self._raw_labels.dtype == np.int64:
65
+ assert self._raw_labels.ndim == 1
66
+ assert np.all(self._raw_labels >= 0)
67
+ return self._raw_labels
68
+
69
+ def close(self): # to be overridden by subclass
70
+ pass
71
+
72
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
73
+ raise NotImplementedError
74
+
75
+ def _load_raw_labels(self): # to be overridden by subclass
76
+ raise NotImplementedError
77
+
78
+ def __getstate__(self):
79
+ return dict(self.__dict__, _raw_labels=None)
80
+
81
+ def __del__(self):
82
+ try:
83
+ self.close()
84
+ except:
85
+ pass
86
+
87
+ def __len__(self):
88
+ return self._raw_idx.size
89
+
90
+ def __getitem__(self, idx):
91
+ raw_idx = self._raw_idx[idx]
92
+ image = self._cached_images.get(raw_idx, None)
93
+ if image is None:
94
+ image = self._load_raw_image(raw_idx)
95
+ if self._cache:
96
+ self._cached_images[raw_idx] = image
97
+ assert isinstance(image, np.ndarray)
98
+ assert list(image.shape) == self.image_shape
99
+ assert image.dtype == np.uint8
100
+ if self._xflip[idx]:
101
+ assert image.ndim == 3 # CHW
102
+ image = image[:, :, ::-1]
103
+ return image.copy(), self.get_label(idx)
104
+
105
+ def get_label(self, idx):
106
+ label = self._get_raw_labels()[self._raw_idx[idx]]
107
+ if label.dtype == np.int64:
108
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
109
+ onehot[label] = 1
110
+ label = onehot
111
+ return label.copy()
112
+
113
+ def get_details(self, idx):
114
+ d = dnnlib.EasyDict()
115
+ d.raw_idx = int(self._raw_idx[idx])
116
+ d.xflip = (int(self._xflip[idx]) != 0)
117
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
118
+ return d
119
+
120
+ @property
121
+ def name(self):
122
+ return self._name
123
+
124
+ @property
125
+ def image_shape(self):
126
+ return list(self._raw_shape[1:])
127
+
128
+ @property
129
+ def num_channels(self):
130
+ assert len(self.image_shape) == 3 # CHW
131
+ return self.image_shape[0]
132
+
133
+ @property
134
+ def resolution(self):
135
+ assert len(self.image_shape) == 3 # CHW
136
+ assert self.image_shape[1] == self.image_shape[2]
137
+ return self.image_shape[1]
138
+
139
+ @property
140
+ def label_shape(self):
141
+ if self._label_shape is None:
142
+ raw_labels = self._get_raw_labels()
143
+ if raw_labels.dtype == np.int64:
144
+ self._label_shape = [int(np.max(raw_labels)) + 1]
145
+ else:
146
+ self._label_shape = raw_labels.shape[1:]
147
+ return list(self._label_shape)
148
+
149
+ @property
150
+ def label_dim(self):
151
+ assert len(self.label_shape) == 1
152
+ return self.label_shape[0]
153
+
154
+ @property
155
+ def has_labels(self):
156
+ return any(x != 0 for x in self.label_shape)
157
+
158
+ @property
159
+ def has_onehot_labels(self):
160
+ return self._get_raw_labels().dtype == np.int64
161
+
162
+ #----------------------------------------------------------------------------
163
+ # Dataset subclass that loads images recursively from the specified directory
164
+ # or ZIP file.
165
+
166
+ class ImageFolderDataset(Dataset):
167
+ def __init__(self,
168
+ path, # Path to directory or zip.
169
+ resolution = None, # Ensure specific resolution, None = highest available.
170
+ use_pyspng = True, # Use pyspng if available?
171
+ **super_kwargs, # Additional arguments for the Dataset base class.
172
+ ):
173
+ self._path = path
174
+ self._use_pyspng = use_pyspng
175
+ self._zipfile = None
176
+
177
+ if os.path.isdir(self._path):
178
+ self._type = 'dir'
179
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
180
+ elif self._file_ext(self._path) == '.zip':
181
+ self._type = 'zip'
182
+ self._all_fnames = set(self._get_zipfile().namelist())
183
+ else:
184
+ raise IOError('Path must point to a directory or zip')
185
+
186
+ PIL.Image.init()
187
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
188
+ if len(self._image_fnames) == 0:
189
+ raise IOError('No image files found in the specified path')
190
+
191
+ name = os.path.splitext(os.path.basename(self._path))[0]
192
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
193
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
194
+ raise IOError('Image files do not match the specified resolution')
195
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
196
+
197
+ @staticmethod
198
+ def _file_ext(fname):
199
+ return os.path.splitext(fname)[1].lower()
200
+
201
+ def _get_zipfile(self):
202
+ assert self._type == 'zip'
203
+ if self._zipfile is None:
204
+ self._zipfile = zipfile.ZipFile(self._path)
205
+ return self._zipfile
206
+
207
+ def _open_file(self, fname):
208
+ if self._type == 'dir':
209
+ return open(os.path.join(self._path, fname), 'rb')
210
+ if self._type == 'zip':
211
+ return self._get_zipfile().open(fname, 'r')
212
+ return None
213
+
214
+ def close(self):
215
+ try:
216
+ if self._zipfile is not None:
217
+ self._zipfile.close()
218
+ finally:
219
+ self._zipfile = None
220
+
221
+ def __getstate__(self):
222
+ return dict(super().__getstate__(), _zipfile=None)
223
+
224
+ def _load_raw_image(self, raw_idx):
225
+ fname = self._image_fnames[raw_idx]
226
+ with self._open_file(fname) as f:
227
+ if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
228
+ image = pyspng.load(f.read())
229
+ else:
230
+ image = np.array(PIL.Image.open(f))
231
+ if image.ndim == 2:
232
+ image = image[:, :, np.newaxis] # HW => HWC
233
+ image = image.transpose(2, 0, 1) # HWC => CHW
234
+ return image
235
+
236
+ def _load_raw_labels(self):
237
+ fname = 'dataset.json'
238
+ if fname not in self._all_fnames:
239
+ return None
240
+ with self._open_file(fname) as f:
241
+ labels = json.load(f)['labels']
242
+ if labels is None:
243
+ return None
244
+ labels = dict(labels)
245
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
246
+ labels = np.array(labels)
247
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
248
+ return labels
249
+
250
+ #----------------------------------------------------------------------------
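A minimal sketch of how the ImageFolderDataset above is typically consumed (illustrative only): the zip path is a placeholder and the edm.training.dataset import path is an assumption; batches come back as uint8 CHW image tensors and float32 one-hot (or empty) label tensors.

import torch
from edm.training.dataset import ImageFolderDataset  # assumed module path

dataset = ImageFolderDataset(path='datasets/example-32x32.zip', use_labels=True, xflip=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
images, labels = next(iter(loader))  # uint8 images in [0, 255]; normalization happens later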
edm/training/loss.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Loss functions used in the paper
9
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10
+
11
+ import torch
12
+ from edm.torch_utils import persistence
13
+
14
+ #----------------------------------------------------------------------------
15
+ # Loss function corresponding to the variance preserving (VP) formulation
16
+ # from the paper "Score-Based Generative Modeling through Stochastic
17
+ # Differential Equations".
18
+
19
+ @persistence.persistent_class
20
+ class VPLoss:
21
+ def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
22
+ self.beta_d = beta_d
23
+ self.beta_min = beta_min
24
+ self.epsilon_t = epsilon_t
25
+
26
+ def __call__(self, net, images, labels, augment_pipe=None):
27
+ rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
28
+ sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
29
+ weight = 1 / sigma ** 2
30
+ y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
31
+ n = torch.randn_like(y) * sigma
32
+ D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
33
+ loss = weight * ((D_yn - y) ** 2)
34
+ return loss
35
+
36
+ def sigma(self, t):
37
+ t = torch.as_tensor(t)
38
+ return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
39
+
40
+ #----------------------------------------------------------------------------
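For reference, the noise schedule implemented by VPLoss.sigma above can be written in closed form as

\sigma(t) = \sqrt{\exp\!\Big(\tfrac{1}{2}\,\beta_d\, t^{2} + \beta_{\min}\, t\Big) - 1}, \qquad t \in [\epsilon_t,\, 1],

where the uniform draw enters through the 1 + rnd_uniform * (epsilon_t - 1) reparameterization, keeping t bounded away from zero so that sigma stays strictly positive.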
41
+ # Loss function corresponding to the variance exploding (VE) formulation
42
+ # from the paper "Score-Based Generative Modeling through Stochastic
43
+ # Differential Equations".
44
+
45
+ @persistence.persistent_class
46
+ class VELoss:
47
+ def __init__(self, sigma_min=0.02, sigma_max=100):
48
+ self.sigma_min = sigma_min
49
+ self.sigma_max = sigma_max
50
+
51
+ def __call__(self, net, images, labels, augment_pipe=None):
52
+ rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
53
+ sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
54
+ weight = 1 / sigma ** 2
55
+ y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
56
+ n = torch.randn_like(y) * sigma
57
+ D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
58
+ loss = weight * ((D_yn - y) ** 2)
59
+ return loss
60
+
61
+ #----------------------------------------------------------------------------
62
+ # Improved loss function proposed in the paper "Elucidating the Design Space
63
+ # of Diffusion-Based Generative Models" (EDM).
64
+
65
+ @persistence.persistent_class
66
+ class EDMLoss:
67
+ def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
68
+ self.P_mean = P_mean
69
+ self.P_std = P_std
70
+ self.sigma_data = sigma_data
71
+
72
+ def __call__(self, net, images, labels=None, augment_pipe=None):
73
+ rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
74
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
75
+ weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
76
+ y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
77
+ n = torch.randn_like(y) * sigma
78
+ D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
79
+ loss = weight * ((D_yn - y) ** 2)
80
+ return loss
81
+
82
+ #----------------------------------------------------------------------------
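A minimal sketch of evaluating the EDM loss above on one batch (illustrative only): the module paths, the tiny SongUNet configuration, and the random stand-in batch are assumptions for the example; EDMPrecond is defined in networks.py below, and the loss returned is a per-pixel weighted squared error that the training loop reduces before backpropagation.

import torch
from edm.training.loss import EDMLoss           # assumed module path
from edm.training.networks import EDMPrecond    # defined in networks.py below

net = EDMPrecond(img_resolution=32, img_channels=3, label_dim=0,
                 model_type='SongUNet', model_channels=32, channel_mult=[1, 2])
images = torch.randn(2, 3, 32, 32)              # stand-in for a real training batch
loss_fn = EDMLoss(P_mean=-1.2, P_std=1.2, sigma_data=0.5)
per_pixel = loss_fn(net=net, images=images, labels=None, augment_pipe=None)
per_pixel.mean().backward()                     # the real loop sums and rescales instead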
edm/training/networks.py ADDED
@@ -0,0 +1,673 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Model architectures and preconditioning schemes used in the paper
9
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch_utils import persistence
14
+ from torch.nn.functional import silu
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Unified routine for initializing weights and biases.
18
+
19
+ def weight_init(shape, mode, fan_in, fan_out):
20
+ if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
21
+ if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
22
+ if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
23
+ if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape)
24
+ raise ValueError(f'Invalid init mode "{mode}"')
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Fully-connected layer.
28
+
29
+ @persistence.persistent_class
30
+ class Linear(torch.nn.Module):
31
+ def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
32
+ super().__init__()
33
+ self.in_features = in_features
34
+ self.out_features = out_features
35
+ init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
36
+ self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
37
+ self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
38
+
39
+ def forward(self, x):
40
+ x = x @ self.weight.to(x.dtype).t()
41
+ if self.bias is not None:
42
+ x = x.add_(self.bias.to(x.dtype))
43
+ return x
44
+
45
+ #----------------------------------------------------------------------------
46
+ # Convolutional layer with optional up/downsampling.
47
+
48
+ @persistence.persistent_class
49
+ class Conv2d(torch.nn.Module):
50
+ def __init__(self,
51
+ in_channels, out_channels, kernel, bias=True, up=False, down=False,
52
+ resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0,
53
+ ):
54
+ assert not (up and down)
55
+ super().__init__()
56
+ self.in_channels = in_channels
57
+ self.out_channels = out_channels
58
+ self.up = up
59
+ self.down = down
60
+ self.fused_resample = fused_resample
61
+ init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel)
62
+ self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None
63
+ self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None
64
+ f = torch.as_tensor(resample_filter, dtype=torch.float32)
65
+ f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
66
+ self.register_buffer('resample_filter', f if up or down else None)
67
+
68
+ def forward(self, x):
69
+ w = self.weight.to(x.dtype) if self.weight is not None else None
70
+ b = self.bias.to(x.dtype) if self.bias is not None else None
71
+ f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None
72
+ w_pad = w.shape[-1] // 2 if w is not None else 0
73
+ f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0
74
+
75
+ if self.fused_resample and self.up and w is not None:
76
+ x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0))
77
+ x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
78
+ elif self.fused_resample and self.down and w is not None:
79
+ x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad)
80
+ x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2)
81
+ else:
82
+ if self.up:
83
+ x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
84
+ if self.down:
85
+ x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
86
+ if w is not None:
87
+ x = torch.nn.functional.conv2d(x, w, padding=w_pad)
88
+ if b is not None:
89
+ x = x.add_(b.reshape(1, -1, 1, 1))
90
+ return x
91
+
92
+ #----------------------------------------------------------------------------
93
+ # Group normalization.
94
+
95
+ @persistence.persistent_class
96
+ class GroupNorm(torch.nn.Module):
97
+ def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
98
+ super().__init__()
99
+ self.num_groups = min(num_groups, num_channels // min_channels_per_group)
100
+ self.eps = eps
101
+ self.weight = torch.nn.Parameter(torch.ones(num_channels))
102
+ self.bias = torch.nn.Parameter(torch.zeros(num_channels))
103
+
104
+ def forward(self, x):
105
+ x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps)
106
+ return x
107
+
108
+ #----------------------------------------------------------------------------
109
+ # Attention weight computation, i.e., softmax(Q^T * K).
110
+ # Performs all computation using FP32, but uses the original datatype for
111
+ # inputs/outputs/gradients to conserve memory.
112
+
113
+ class AttentionOp(torch.autograd.Function):
114
+ @staticmethod
115
+ def forward(ctx, q, k):
116
+ w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype)
117
+ ctx.save_for_backward(q, k, w)
118
+ return w
119
+
120
+ @staticmethod
121
+ def backward(ctx, dw):
122
+ q, k, w = ctx.saved_tensors
123
+ db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32)
124
+ dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1])
125
+ dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1])
126
+ return dq, dk
127
+
128
+ #----------------------------------------------------------------------------
129
+ # Unified U-Net block with optional up/downsampling and self-attention.
130
+ # Represents the union of all features employed by the DDPM++, NCSN++, and
131
+ # ADM architectures.
132
+
133
+ @persistence.persistent_class
134
+ class UNetBlock(torch.nn.Module):
135
+ def __init__(self,
136
+ in_channels, out_channels, emb_channels, up=False, down=False, attention=False,
137
+ num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5,
138
+ resample_filter=[1,1], resample_proj=False, adaptive_scale=True,
139
+ init=dict(), init_zero=dict(init_weight=0), init_attn=None,
140
+ ):
141
+ super().__init__()
142
+ self.in_channels = in_channels
143
+ self.out_channels = out_channels
144
+ self.emb_channels = emb_channels
145
+ self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head
146
+ self.dropout = dropout
147
+ self.skip_scale = skip_scale
148
+ self.adaptive_scale = adaptive_scale
149
+
150
+ self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
151
+ self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init)
152
+ self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init)
153
+ self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
154
+ self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero)
155
+
156
+ self.skip = None
157
+ if out_channels != in_channels or up or down:
158
+ kernel = 1 if resample_proj or out_channels != in_channels else 0
159
+ self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init)
160
+
161
+ if self.num_heads:
162
+ self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
163
+ self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init))
164
+ self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero)
165
+
166
+ def forward(self, x, emb):
167
+ orig = x
168
+ x = self.conv0(silu(self.norm0(x)))
169
+
170
+ params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
171
+ if self.adaptive_scale:
172
+ scale, shift = params.chunk(chunks=2, dim=1)
173
+ x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
174
+ else:
175
+ x = silu(self.norm1(x.add_(params)))
176
+
177
+ x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
178
+ x = x.add_(self.skip(orig) if self.skip is not None else orig)
179
+ x = x * self.skip_scale
180
+
181
+ if self.num_heads:
182
+ q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
183
+ w = AttentionOp.apply(q, k)
184
+ a = torch.einsum('nqk,nck->ncq', w, v)
185
+ x = self.proj(a.reshape(*x.shape)).add_(x)
186
+ x = x * self.skip_scale
187
+ return x
188
+
189
+ #----------------------------------------------------------------------------
190
+ # Timestep embedding used in the DDPM++ and ADM architectures.
191
+
192
+ @persistence.persistent_class
193
+ class PositionalEmbedding(torch.nn.Module):
194
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
195
+ super().__init__()
196
+ self.num_channels = num_channels
197
+ self.max_positions = max_positions
198
+ self.endpoint = endpoint
199
+
200
+ def forward(self, x):
201
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
202
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
203
+ freqs = (1 / self.max_positions) ** freqs
204
+ x = x.ger(freqs.to(x.dtype))
205
+ x = torch.cat([x.cos(), x.sin()], dim=1)
206
+ return x
207
+
208
+ #----------------------------------------------------------------------------
209
+ # Timestep embedding used in the NCSN++ architecture.
210
+
211
+ @persistence.persistent_class
212
+ class FourierEmbedding(torch.nn.Module):
213
+ def __init__(self, num_channels, scale=16):
214
+ super().__init__()
215
+ self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)
216
+
217
+ def forward(self, x):
218
+ x = x.ger((2 * np.pi * self.freqs).to(x.dtype))
219
+ x = torch.cat([x.cos(), x.sin()], dim=1)
220
+ return x
221
+
222
+ #----------------------------------------------------------------------------
223
+ # Reimplementation of the DDPM++ and NCSN++ architectures from the paper
224
+ # "Score-Based Generative Modeling through Stochastic Differential
225
+ # Equations". Equivalent to the original implementation by Song et al.,
226
+ # available at https://github.com/yang-song/score_sde_pytorch
227
+
228
+ @persistence.persistent_class
229
+ class SongUNet(torch.nn.Module):
230
+ def __init__(self,
231
+ img_resolution, # Image resolution at input/output.
232
+ in_channels, # Number of color channels at input.
233
+ out_channels, # Number of color channels at output.
234
+ label_dim = 0, # Number of class labels, 0 = unconditional.
235
+ augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
236
+
237
+ model_channels = 128, # Base multiplier for the number of channels.
238
+ channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels.
239
+ channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
240
+ num_blocks = 4, # Number of residual blocks per resolution.
241
+ attn_resolutions = [16], # List of resolutions with self-attention.
242
+ dropout = 0.10, # Dropout probability of intermediate activations.
243
+ label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
244
+
245
+ embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
246
+ channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
247
+ encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
248
+ decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
249
+ resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
250
+ ):
251
+ assert embedding_type in ['fourier', 'positional']
252
+ assert encoder_type in ['standard', 'skip', 'residual']
253
+ assert decoder_type in ['standard', 'skip']
254
+
255
+ super().__init__()
256
+ self.label_dropout = label_dropout
257
+ emb_channels = model_channels * channel_mult_emb
258
+ noise_channels = model_channels * channel_mult_noise
259
+ init = dict(init_mode='xavier_uniform')
260
+ init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
261
+ init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
262
+ block_kwargs = dict(
263
+ emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6,
264
+ resample_filter=resample_filter, resample_proj=True, adaptive_scale=False,
265
+ init=init, init_zero=init_zero, init_attn=init_attn,
266
+ )
267
+
268
+ # Mapping.
269
+ self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels)
270
+ self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None
271
+ self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None
272
+ self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
273
+ self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
274
+
275
+ # Encoder.
276
+ self.enc = torch.nn.ModuleDict()
277
+ cout = in_channels
278
+ caux = in_channels
279
+ for level, mult in enumerate(channel_mult):
280
+ res = img_resolution >> level
281
+ if level == 0:
282
+ cin = cout
283
+ cout = model_channels
284
+ self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
285
+ else:
286
+ self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
287
+ if encoder_type == 'skip':
288
+ self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter)
289
+ self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init)
290
+ if encoder_type == 'residual':
291
+ self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init)
292
+ caux = cout
293
+ for idx in range(num_blocks):
294
+ cin = cout
295
+ cout = model_channels * mult
296
+ attn = (res in attn_resolutions)
297
+ self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
298
+ skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name]
299
+
300
+ # Decoder.
301
+ self.dec = torch.nn.ModuleDict()
302
+ for level, mult in reversed(list(enumerate(channel_mult))):
303
+ res = img_resolution >> level
304
+ if level == len(channel_mult) - 1:
305
+ self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
306
+ self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
307
+ else:
308
+ self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
309
+ for idx in range(num_blocks + 1):
310
+ cin = cout + skips.pop()
311
+ cout = model_channels * mult
312
+ attn = (idx == num_blocks and res in attn_resolutions)
313
+ self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
314
+ if decoder_type == 'skip' or level == 0:
315
+ if decoder_type == 'skip' and level < len(channel_mult) - 1:
316
+ self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter)
317
+ self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6)
318
+ self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
319
+
320
+ def forward(self, x, noise_labels, class_labels, augment_labels=None):
321
+ # Mapping.
322
+ emb = self.map_noise(noise_labels)
323
+ emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
324
+ if self.map_label is not None:
325
+ tmp = class_labels
326
+ if self.training and self.label_dropout:
327
+ tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
328
+ emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features))
329
+ if self.map_augment is not None and augment_labels is not None:
330
+ emb = emb + self.map_augment(augment_labels)
331
+ emb = silu(self.map_layer0(emb))
332
+ emb = silu(self.map_layer1(emb))
333
+
334
+ # Encoder.
335
+ skips = []
336
+ aux = x
337
+ for name, block in self.enc.items():
338
+ if 'aux_down' in name:
339
+ aux = block(aux)
340
+ elif 'aux_skip' in name:
341
+ x = skips[-1] = x + block(aux)
342
+ elif 'aux_residual' in name:
343
+ x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
344
+ else:
345
+ x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
346
+ skips.append(x)
347
+
348
+ # Decoder.
349
+ aux = None
350
+ tmp = None
351
+ for name, block in self.dec.items():
352
+ if 'aux_up' in name:
353
+ aux = block(aux)
354
+ elif 'aux_norm' in name:
355
+ tmp = block(x)
356
+ elif 'aux_conv' in name:
357
+ tmp = block(silu(tmp))
358
+ aux = tmp if aux is None else tmp + aux
359
+ else:
360
+ if x.shape[1] != block.in_channels:
361
+ x = torch.cat([x, skips.pop()], dim=1)
362
+ x = block(x, emb)
363
+ return aux
364
+
365
+ #----------------------------------------------------------------------------
366
+ # Reimplementation of the ADM architecture from the paper
367
+ # "Diffusion Models Beat GANS on Image Synthesis". Equivalent to the
368
+ # original implementation by Dhariwal and Nichol, available at
369
+ # https://github.com/openai/guided-diffusion
370
+
371
+ @persistence.persistent_class
372
+ class DhariwalUNet(torch.nn.Module):
373
+ def __init__(self,
374
+ img_resolution, # Image resolution at input/output.
375
+ in_channels, # Number of color channels at input.
376
+ out_channels, # Number of color channels at output.
377
+ label_dim = 0, # Number of class labels, 0 = unconditional.
378
+ augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
379
+
380
+ model_channels = 192, # Base multiplier for the number of channels.
381
+ channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels.
382
+ channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
383
+ num_blocks = 3, # Number of residual blocks per resolution.
384
+ attn_resolutions = [32,16,8], # List of resolutions with self-attention.
385
+ dropout = 0.10, # Dropout probability of intermediate activations.
386
+ label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
387
+ ):
388
+ super().__init__()
389
+ self.label_dropout = label_dropout
390
+ emb_channels = model_channels * channel_mult_emb
391
+ init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3))
392
+ init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0)
393
+ block_kwargs = dict(emb_channels=emb_channels, channels_per_head=64, dropout=dropout, init=init, init_zero=init_zero)
394
+
395
+ # Mapping.
396
+ self.map_noise = PositionalEmbedding(num_channels=model_channels)
397
+ self.map_augment = Linear(in_features=augment_dim, out_features=model_channels, bias=False, **init_zero) if augment_dim else None
398
+ self.map_layer0 = Linear(in_features=model_channels, out_features=emb_channels, **init)
399
+ self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
400
+ self.map_label = Linear(in_features=label_dim, out_features=emb_channels, bias=False, init_mode='kaiming_normal', init_weight=np.sqrt(label_dim)) if label_dim else None
401
+
402
+ # Encoder.
403
+ self.enc = torch.nn.ModuleDict()
404
+ cout = in_channels
405
+ for level, mult in enumerate(channel_mult):
406
+ res = img_resolution >> level
407
+ if level == 0:
408
+ cin = cout
409
+ cout = model_channels * mult
410
+ self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
411
+ else:
412
+ self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
413
+ for idx in range(num_blocks):
414
+ cin = cout
415
+ cout = model_channels * mult
416
+ self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
417
+ skips = [block.out_channels for block in self.enc.values()]
418
+
419
+ # Decoder.
420
+ self.dec = torch.nn.ModuleDict()
421
+ for level, mult in reversed(list(enumerate(channel_mult))):
422
+ res = img_resolution >> level
423
+ if level == len(channel_mult) - 1:
424
+ self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
425
+ self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
426
+ else:
427
+ self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
428
+ for idx in range(num_blocks + 1):
429
+ cin = cout + skips.pop()
430
+ cout = model_channels * mult
431
+ self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
432
+ self.out_norm = GroupNorm(num_channels=cout)
433
+ self.out_conv = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
434
+
435
+ def forward(self, x, noise_labels, class_labels, augment_labels=None):
436
+ # Mapping.
437
+ emb = self.map_noise(noise_labels)
438
+ if self.map_augment is not None and augment_labels is not None:
439
+ emb = emb + self.map_augment(augment_labels)
440
+ emb = silu(self.map_layer0(emb))
441
+ emb = self.map_layer1(emb)
442
+ if self.map_label is not None:
443
+ tmp = class_labels
444
+ if self.training and self.label_dropout:
445
+ tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
446
+ emb = emb + self.map_label(tmp)
447
+ emb = silu(emb)
448
+
449
+ # Encoder.
450
+ skips = []
451
+ for block in self.enc.values():
452
+ x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
453
+ skips.append(x)
454
+
455
+ # Decoder.
456
+ for block in self.dec.values():
457
+ if x.shape[1] != block.in_channels:
458
+ x = torch.cat([x, skips.pop()], dim=1)
459
+ x = block(x, emb)
460
+ x = self.out_conv(silu(self.out_norm(x)))
461
+ return x
462
+
463
+ #----------------------------------------------------------------------------
464
+ # Preconditioning corresponding to the variance preserving (VP) formulation
465
+ # from the paper "Score-Based Generative Modeling through Stochastic
466
+ # Differential Equations".
467
+
468
+ @persistence.persistent_class
469
+ class VPPrecond(torch.nn.Module):
470
+ def __init__(self,
471
+ img_resolution, # Image resolution.
472
+ img_channels, # Number of color channels.
473
+ label_dim = 0, # Number of class labels, 0 = unconditional.
474
+ use_fp16 = False, # Execute the underlying model at FP16 precision?
475
+ beta_d = 19.9, # Extent of the noise level schedule.
476
+ beta_min = 0.1, # Initial slope of the noise level schedule.
477
+ M = 1000, # Original number of timesteps in the DDPM formulation.
478
+ epsilon_t = 1e-5, # Minimum t-value used during training.
479
+ model_type = 'SongUNet', # Class name of the underlying model.
480
+ **model_kwargs, # Keyword arguments for the underlying model.
481
+ ):
482
+ super().__init__()
483
+ self.img_resolution = img_resolution
484
+ self.img_channels = img_channels
485
+ self.label_dim = label_dim
486
+ self.use_fp16 = use_fp16
487
+ self.beta_d = beta_d
488
+ self.beta_min = beta_min
489
+ self.M = M
490
+ self.epsilon_t = epsilon_t
491
+ self.sigma_min = float(self.sigma(epsilon_t))
492
+ self.sigma_max = float(self.sigma(1))
493
+ self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
494
+
495
+ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
496
+ x = x.to(torch.float32)
497
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
498
+ class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
499
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
500
+
501
+ c_skip = 1
502
+ c_out = -sigma
503
+ c_in = 1 / (sigma ** 2 + 1).sqrt()
504
+ c_noise = (self.M - 1) * self.sigma_inv(sigma)
505
+
506
+ F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
507
+ assert F_x.dtype == dtype
508
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
509
+ return D_x
510
+
511
+ def sigma(self, t):
512
+ t = torch.as_tensor(t)
513
+ return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
514
+
515
+ def sigma_inv(self, sigma):
516
+ sigma = torch.as_tensor(sigma)
517
+ return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d
518
+
519
+ def round_sigma(self, sigma):
520
+ return torch.as_tensor(sigma)
521
+
522
+ #----------------------------------------------------------------------------
523
+ # Preconditioning corresponding to the variance exploding (VE) formulation
524
+ # from the paper "Score-Based Generative Modeling through Stochastic
525
+ # Differential Equations".
526
+
527
+ @persistence.persistent_class
528
+ class VEPrecond(torch.nn.Module):
529
+ def __init__(self,
530
+ img_resolution, # Image resolution.
531
+ img_channels, # Number of color channels.
532
+ label_dim = 0, # Number of class labels, 0 = unconditional.
533
+ use_fp16 = False, # Execute the underlying model at FP16 precision?
534
+ sigma_min = 0.02, # Minimum supported noise level.
535
+ sigma_max = 100, # Maximum supported noise level.
536
+ model_type = 'SongUNet', # Class name of the underlying model.
537
+ **model_kwargs, # Keyword arguments for the underlying model.
538
+ ):
539
+ super().__init__()
540
+ self.img_resolution = img_resolution
541
+ self.img_channels = img_channels
542
+ self.label_dim = label_dim
543
+ self.use_fp16 = use_fp16
544
+ self.sigma_min = sigma_min
545
+ self.sigma_max = sigma_max
546
+ self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
547
+
548
+ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
549
+ x = x.to(torch.float32)
550
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
551
+ class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
552
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
553
+
554
+ c_skip = 1
555
+ c_out = sigma
556
+ c_in = 1
557
+ c_noise = (0.5 * sigma).log()
558
+
559
+ F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
560
+ assert F_x.dtype == dtype
561
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
562
+ return D_x
563
+
564
+ def round_sigma(self, sigma):
565
+ return torch.as_tensor(sigma)
566
+
567
+ #----------------------------------------------------------------------------
568
+ # Preconditioning corresponding to improved DDPM (iDDPM) formulation from
569
+ # the paper "Improved Denoising Diffusion Probabilistic Models".
570
+
571
+ @persistence.persistent_class
572
+ class iDDPMPrecond(torch.nn.Module):
573
+ def __init__(self,
574
+ img_resolution, # Image resolution.
575
+ img_channels, # Number of color channels.
576
+ label_dim = 0, # Number of class labels, 0 = unconditional.
577
+ use_fp16 = False, # Execute the underlying model at FP16 precision?
578
+ C_1 = 0.001, # Timestep adjustment at low noise levels.
579
+ C_2 = 0.008, # Timestep adjustment at high noise levels.
580
+ M = 1000, # Original number of timesteps in the DDPM formulation.
581
+ model_type = 'DhariwalUNet', # Class name of the underlying model.
582
+ **model_kwargs, # Keyword arguments for the underlying model.
583
+ ):
584
+ super().__init__()
585
+ self.img_resolution = img_resolution
586
+ self.img_channels = img_channels
587
+ self.label_dim = label_dim
588
+ self.use_fp16 = use_fp16
589
+ self.C_1 = C_1
590
+ self.C_2 = C_2
591
+ self.M = M
592
+ self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels*2, label_dim=label_dim, **model_kwargs)
593
+
594
+ u = torch.zeros(M + 1)
595
+ for j in range(M, 0, -1): # M, ..., 1
596
+ u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
597
+ self.register_buffer('u', u)
598
+ self.sigma_min = float(u[M - 1])
599
+ self.sigma_max = float(u[0])
600
+
601
+ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
602
+ x = x.to(torch.float32)
603
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
604
+ class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
605
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
606
+
607
+ c_skip = 1
608
+ c_out = -sigma
609
+ c_in = 1 / (sigma ** 2 + 1).sqrt()
610
+ c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
611
+
612
+ F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
613
+ assert F_x.dtype == dtype
614
+ D_x = c_skip * x + c_out * F_x[:, :self.img_channels].to(torch.float32)
615
+ return D_x
616
+
617
+ def alpha_bar(self, j):
618
+ j = torch.as_tensor(j)
619
+ return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2
620
+
621
+ def round_sigma(self, sigma, return_index=False):
622
+ sigma = torch.as_tensor(sigma)
623
+ index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
624
+ result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
625
+ return result.reshape(sigma.shape).to(sigma.device)
626
+
627
+ #----------------------------------------------------------------------------
628
+ # Improved preconditioning proposed in the paper "Elucidating the Design
629
+ # Space of Diffusion-Based Generative Models" (EDM).
630
+
631
+ @persistence.persistent_class
632
+ class EDMPrecond(torch.nn.Module):
633
+ def __init__(self,
634
+ img_resolution, # Image resolution.
635
+ img_channels, # Number of color channels.
636
+ label_dim = 0, # Number of class labels, 0 = unconditional.
637
+ use_fp16 = False, # Execute the underlying model at FP16 precision?
638
+ sigma_min = 0, # Minimum supported noise level.
639
+ sigma_max = float('inf'), # Maximum supported noise level.
640
+ sigma_data = 0.5, # Expected standard deviation of the training data.
641
+ model_type = 'DhariwalUNet', # Class name of the underlying model.
642
+ **model_kwargs, # Keyword arguments for the underlying model.
643
+ ):
644
+ super().__init__()
645
+ self.img_resolution = img_resolution
646
+ self.img_channels = img_channels
647
+ self.label_dim = label_dim
648
+ self.use_fp16 = use_fp16
649
+ self.sigma_min = sigma_min
650
+ self.sigma_max = sigma_max
651
+ self.sigma_data = sigma_data
652
+ self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)
653
+
654
+ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
655
+ x = x.to(torch.float32)
656
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
657
+ class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
658
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
659
+
660
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
661
+ c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
662
+ c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
663
+ c_noise = sigma.log() / 4
664
+
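# Behaviour of the coefficients above (straight from the formulas): as sigma -> 0,
# c_skip -> 1 and c_out -> 0, so D_x stays close to the (nearly clean) input;
# as sigma -> inf, c_skip -> 0 and c_out -> sigma_data, so the network output
# alone must carry the signal. c_in normalizes the input to unit variance
# (Var[x] = sigma_data^2 + sigma^2), and c_noise = ln(sigma) / 4 is the scalar
# conditioning passed to the underlying model.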
665
+ F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
666
+ assert F_x.dtype == dtype
667
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
668
+ return D_x
669
+
670
+ def round_sigma(self, sigma):
671
+ return torch.as_tensor(sigma)
672
+
673
+ #----------------------------------------------------------------------------
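The EDMPrecond wrapper above implements the denoiser D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise). As a hedged illustration of how such a wrapper is consumed (not the repository's actual sampler), one Euler step of the probability-flow ODE could look like this, assuming net is a trained EDMPrecond instance and x_cur, t_cur, t_next already hold the current sample and the current/next noise levels:

denoised = net(x_cur, t_cur)                # D(x; sigma) via the preconditioning above
d_cur = (x_cur - denoised) / t_cur          # dx/dsigma of the probability-flow ODE
x_next = x_cur + (t_next - t_cur) * d_cur   # Euler step from sigma = t_cur to t_next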
edm/training/training_loop.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Main training loop."""
9
+
10
+ import os
11
+ import time
12
+ import copy
13
+ import json
14
+ import pickle
15
+ import psutil
16
+ import numpy as np
17
+ import torch
18
+ import dnnlib
19
+ from torch_utils import distributed as dist
20
+ from torch_utils import training_stats
21
+ from torch_utils import misc
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def training_loop(
26
+ run_dir = '.', # Output directory.
27
+ dataset_kwargs = {}, # Options for training set.
28
+ data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
29
+ network_kwargs = {}, # Options for model and preconditioning.
30
+ loss_kwargs = {}, # Options for loss function.
31
+ optimizer_kwargs = {}, # Options for optimizer.
32
+ augment_kwargs = None, # Options for augmentation pipeline, None = disable.
33
+ seed = 0, # Global random seed.
34
+ batch_size = 512, # Total batch size for one training iteration.
35
+ batch_gpu = None, # Limit batch size per GPU, None = no limit.
36
+ total_kimg = 200000, # Training duration, measured in thousands of training images.
37
+ ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
38
+ ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
39
+ lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
40
+ loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
41
+ kimg_per_tick = 50, # Interval of progress prints.
42
+ snapshot_ticks = 50, # How often to save network snapshots, None = disable.
43
+ state_dump_ticks = 500, # How often to dump training state, None = disable.
44
+ resume_pkl = None, # Start from the given network snapshot, None = random initialization.
45
+ resume_state_dump = None, # Start from the given training state, None = reset training state.
46
+ resume_kimg = 0, # Start from the given training progress.
47
+ cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
48
+ device = torch.device('cuda'),
49
+ ):
50
+ # Initialize.
51
+ start_time = time.time()
52
+ np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
53
+ torch.manual_seed(np.random.randint(1 << 31))
54
+ torch.backends.cudnn.benchmark = cudnn_benchmark
55
+ torch.backends.cudnn.allow_tf32 = False
56
+ torch.backends.cuda.matmul.allow_tf32 = False
57
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
58
+
59
+ # Select batch size per GPU.
60
+ batch_gpu_total = batch_size // dist.get_world_size()
61
+ if batch_gpu is None or batch_gpu > batch_gpu_total:
62
+ batch_gpu = batch_gpu_total
63
+ num_accumulation_rounds = batch_gpu_total // batch_gpu
64
+ assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
65
+
66
+ # Load dataset.
67
+ dist.print0('Loading dataset...')
68
+ dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
69
+ dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
70
+ dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
71
+
72
+ # Construct network.
73
+ dist.print0('Constructing network...')
74
+ interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
75
+ net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
76
+ net.train().requires_grad_(True).to(device)
77
+ if dist.get_rank() == 0:
78
+ with torch.no_grad():
79
+ images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
80
+ sigma = torch.ones([batch_gpu], device=device)
81
+ labels = torch.zeros([batch_gpu, net.label_dim], device=device)
82
+ misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
83
+
84
+ # Setup optimizer.
85
+ dist.print0('Setting up optimizer...')
86
+ loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
87
+ optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
88
+ augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
89
+ ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
90
+ ema = copy.deepcopy(net).eval().requires_grad_(False)
91
+
92
+ # Resume training from previous snapshot.
93
+ if resume_pkl is not None:
94
+ dist.print0(f'Loading network weights from "{resume_pkl}"...')
95
+ if dist.get_rank() != 0:
96
+ torch.distributed.barrier() # rank 0 goes first
97
+ with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
98
+ data = pickle.load(f)
99
+ if dist.get_rank() == 0:
100
+ torch.distributed.barrier() # other ranks follow
101
+ misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
102
+ misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
103
+ del data # conserve memory
104
+ if resume_state_dump:
105
+ dist.print0(f'Loading training state from "{resume_state_dump}"...')
106
+ data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
107
+ misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
108
+ optimizer.load_state_dict(data['optimizer_state'])
109
+ del data # conserve memory
110
+
111
+ # Train.
112
+ dist.print0(f'Training for {total_kimg} kimg...')
113
+ dist.print0()
114
+ cur_nimg = resume_kimg * 1000
115
+ cur_tick = 0
116
+ tick_start_nimg = cur_nimg
117
+ tick_start_time = time.time()
118
+ maintenance_time = tick_start_time - start_time
119
+ dist.update_progress(cur_nimg // 1000, total_kimg)
120
+ stats_jsonl = None
121
+ while True:
122
+
123
+ # Accumulate gradients.
124
+ optimizer.zero_grad(set_to_none=True)
125
+ for round_idx in range(num_accumulation_rounds):
126
+ with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
127
+ images, labels = next(dataset_iterator)
128
+ images = images.to(device).to(torch.float32) / 127.5 - 1
129
+ labels = labels.to(device)
130
+ loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
131
+ training_stats.report('Loss/loss', loss)
132
+ loss.sum().mul(loss_scaling / batch_gpu_total).backward()
133
+
134
+ # Update weights.
135
+ for g in optimizer.param_groups:
136
+ g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
137
+ for param in net.parameters():
138
+ if param.grad is not None:
139
+ torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
140
+ optimizer.step()
141
+
142
+ # Update EMA.
143
+ ema_halflife_nimg = ema_halflife_kimg * 1000
144
+ if ema_rampup_ratio is not None:
145
+ ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
146
+ ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
147
+ for p_ema, p_net in zip(ema.parameters(), net.parameters()):
148
+ p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
149
+
150
+ # Perform maintenance tasks once per tick.
151
+ cur_nimg += batch_size
152
+ done = (cur_nimg >= total_kimg * 1000)
153
+ if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
154
+ continue
155
+
156
+ # Print status line, accumulating the same information in training_stats.
157
+ tick_end_time = time.time()
158
+ fields = []
159
+ fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
160
+ fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
161
+ fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
162
+ fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
163
+ fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
164
+ fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
165
+ fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
166
+ fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
167
+ fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
168
+ torch.cuda.reset_peak_memory_stats()
169
+ dist.print0(' '.join(fields))
170
+
171
+ # Check for abort.
172
+ if (not done) and dist.should_stop():
173
+ done = True
174
+ dist.print0()
175
+ dist.print0('Aborting...')
176
+
177
+ # Save network snapshot.
178
+ if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
179
+ data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
180
+ for key, value in data.items():
181
+ if isinstance(value, torch.nn.Module):
182
+ value = copy.deepcopy(value).eval().requires_grad_(False)
183
+ misc.check_ddp_consistency(value)
184
+ data[key] = value.cpu()
185
+ del value # conserve memory
186
+ if dist.get_rank() == 0:
187
+ with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
188
+ pickle.dump(data, f)
189
+ del data # conserve memory
190
+
191
+ # Save full dump of the training state.
192
+ if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
193
+ torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
194
+
195
+ # Update logs.
196
+ training_stats.default_collector.update()
197
+ if dist.get_rank() == 0:
198
+ if stats_jsonl is None:
199
+ stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
200
+ stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
201
+ stats_jsonl.flush()
202
+ dist.update_progress(cur_nimg // 1000, total_kimg)
203
+
204
+ # Update state.
205
+ cur_tick += 1
206
+ tick_start_nimg = cur_nimg
207
+ tick_start_time = time.time()
208
+ maintenance_time = tick_start_time - tick_end_time
209
+ if done:
210
+ break
211
+
212
+ # Done.
213
+ dist.print0()
214
+ dist.print0('Exiting...')
215
+
216
+ #----------------------------------------------------------------------------
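For reference, the loop above is driven purely by the keyword groups in its signature; dataset, network, loss and optimizer are all built via dnnlib.util.construct_class_by_name. The sketch below shows how those groups might be assembled for a small single-GPU run. It is an illustration only: dist.init(), training.dataset.ImageFolderDataset, model_type='SongUNet' and the listed hyperparameters are assumptions about the rest of this repository (the real launcher is a separate script), while 'training.networks.EDMPrecond' and 'training.loss.EDMLoss' follow the code and comments above.

import torch
import dnnlib
from torch_utils import distributed as dist
from training import training_loop

dist.init()  # assumed helper that initializes torch.distributed (world size 1 here)
training_loop.training_loop(
    run_dir            = 'training-runs/00000-example',
    dataset_kwargs     = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',
                                         path='datasets/cifar10-32x32.zip', use_labels=False),
    data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1),
    network_kwargs     = dnnlib.EasyDict(class_name='training.networks.EDMPrecond',
                                         model_type='SongUNet', model_channels=32),
    loss_kwargs        = dnnlib.EasyDict(class_name='training.loss.EDMLoss'),
    optimizer_kwargs   = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                         lr=1e-3, betas=[0.9, 0.999], eps=1e-8),
    batch_size         = 32,   # one GPU: batch_gpu_total=32, a single accumulation round
    total_kimg         = 200,  # short illustrative run
    device             = torch.device('cuda'),
)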