44 | This mock-up mimics the look of the in-progress article to inform a design
45 | that embeds the demo into the article. The relevant assets just need to be
46 | migrated into the final article.
47 |
48 |
49 | We’ve developed Glow, a new type of generative model which uses
50 | invertible 1x1 convolutions to create rich, synthetic models of data,
51 | automatically discovering features we can manipulate. The model extends
52 | previous work on reversible generative models, simplifying the
53 | architecture and leading to substantially better results. We’re releasing
54 | code for the model and an online visualization tool so people can explore
55 | and build on these results.
68 | Generative modeling is about observing data, like a set of pictures of
69 | faces, then learning a model of how this data was generated. Learning to
70 | approximate the data-generating process requires learning all structure
71 | present in the data, and successful models should be able to synthesize
72 | outputs that look similar to the data. Accurate generative models have
73 | broad applications, including speech synthesis, text analysis and
74 | synthesis, semi-supervised learning and model-based control. The technique
75 | we propose can be applied to those problems as well.
76 |
77 |
78 |
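To make the mechanism concrete, here is a minimal numpy sketch of an invertible 1x1 convolution (our illustration, not the released Glow code): a square matrix W mixes the channels at every spatial position, its inverse undoes the mixing exactly, and h * w * log|det W| is the term such a layer contributes to the flow's log-likelihood objective.

    import numpy as np

    rng = np.random.default_rng(0)
    c, h, w = 3, 8, 8
    W = np.linalg.qr(rng.normal(size=(c, c)))[0]  # random orthogonal matrix: invertible
    x = rng.normal(size=(h, w, c))                # one activation map, channels last

    y = x @ W.T                                   # forward: mix channels at each pixel
    x_rec = y @ np.linalg.inv(W).T                # inverse: exact reconstruction
    log_det = h * w * np.linalg.slogdet(W)[1]     # log-det term in the objective

    assert np.allclose(x, x_rec)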
--------------------------------------------------------------------------------
/demo/web/load-image.all.min.js:
--------------------------------------------------------------------------------
1 | !function(e){"use strict";function t(e,i,a){var o,n=document.createElement("img");return n.onerror=function(o){return t.onerror(n,o,e,i,a)},n.onload=function(o){return t.onload(n,o,e,i,a)},"string"==typeof e?(t.fetchBlob(e,function(i){i?(e=i,o=t.createObjectURL(e)):(o=e,a&&a.crossOrigin&&(n.crossOrigin=a.crossOrigin)),n.src=o},a),n):t.isInstanceOf("Blob",e)||t.isInstanceOf("File",e)?(o=n._objectURL=t.createObjectURL(e))?(n.src=o,n):t.readFile(e,function(e){var t=e.target;t&&t.result?n.src=t.result:i&&i(e)}):void 0}function i(e,i){!e._objectURL||i&&i.noRevoke||(t.revokeObjectURL(e._objectURL),delete e._objectURL)}var a=e.createObjectURL&&e||e.URL&&URL.revokeObjectURL&&URL||e.webkitURL&&webkitURL;t.fetchBlob=function(e,t,i){t()},t.isInstanceOf=function(e,t){return Object.prototype.toString.call(t)==="[object "+e+"]"},t.transform=function(e,t,i,a,o){i(e,o)},t.onerror=function(e,t,a,o,n){i(e,n),o&&o.call(e,t)},t.onload=function(e,a,o,n,r){i(e,r),n&&t.transform(e,r,n,o,{})},t.createObjectURL=function(e){return!!a&&a.createObjectURL(e)},t.revokeObjectURL=function(e){return!!a&&a.revokeObjectURL(e)},t.readFile=function(t,i,a){if(e.FileReader){var o=new FileReader;if(o.onload=o.onerror=i,a=a||"readAsDataURL",o[a])return o[a](t),o}return!1},"function"==typeof define&&define.amd?define(function(){return t}):"object"==typeof module&&module.exports?module.exports=t:e.loadImage=t}("undefined"!=typeof window&&window||this),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t=e.transform;e.transform=function(i,a,o,n,r){t.call(e,e.scale(i,a,r),a,o,n,r)},e.transformCoordinates=function(){},e.getTransformedOptions=function(e,t){var i,a,o,n,r=t.aspectRatio;if(!r)return t;i={};for(a in t)t.hasOwnProperty(a)&&(i[a]=t[a]);return i.crop=!0,o=e.naturalWidth||e.width,n=e.naturalHeight||e.height,o/n>r?(i.maxWidth=n*r,i.maxHeight=n):(i.maxWidth=o,i.maxHeight=o/r),i},e.renderImageToCanvas=function(e,t,i,a,o,n,r,s,l,d){return e.getContext("2d").drawImage(t,i,a,o,n,r,s,l,d),e},e.hasCanvasOption=function(e){return e.canvas||e.crop||!!e.aspectRatio},e.scale=function(t,i,a){function o(){var e=Math.max((l||v)/v,(d||P)/P);e>1&&(v*=e,P*=e)}function n(){var e=Math.min((r||v)/v,(s||P)/P);e<1&&(v*=e,P*=e)}i=i||{};var r,s,l,d,c,u,f,g,h,m,p,S=document.createElement("canvas"),b=t.getContext||e.hasCanvasOption(i)&&S.getContext,y=t.naturalWidth||t.width,x=t.naturalHeight||t.height,v=y,P=x;if(b&&(f=(i=e.getTransformedOptions(t,i,a)).left||0,g=i.top||0,i.sourceWidth?(c=i.sourceWidth,void 0!==i.right&&void 0===i.left&&(f=y-c-i.right)):c=y-f-(i.right||0),i.sourceHeight?(u=i.sourceHeight,void 0!==i.bottom&&void 0===i.top&&(g=x-u-i.bottom)):u=x-g-(i.bottom||0),v=c,P=u),r=i.maxWidth,s=i.maxHeight,l=i.minWidth,d=i.minHeight,b&&r&&s&&i.crop?(v=r,P=s,(p=c/u-r/s)<0?(u=s*c/r,void 0===i.top&&void 0===i.bottom&&(g=(x-u)/2)):p>0&&(c=r*u/s,void 0===i.left&&void 0===i.right&&(f=(y-c)/2))):((i.contain||i.cover)&&(l=r=r||l,d=s=s||d),i.cover?(n(),o()):(o(),n())),b){if((h=i.pixelRatio)>1&&(S.style.width=v+"px",S.style.height=P+"px",v*=h,P*=h,S.getContext("2d").scale(h,h)),(m=i.downsamplingRatio)>0&&m<1&&vv;)S.width=c*m,S.height=u*m,e.renderImageToCanvas(S,t,f,g,c,u,0,0,S.width,S.height),f=0,g=0,c=S.width,u=S.height,(t=document.createElement("canvas")).width=c,t.height=u,e.renderImageToCanvas(t,S,0,0,c,u,0,0,c,u);return 
S.width=v,S.height=P,e.transformCoordinates(S,i),e.renderImageToCanvas(S,t,f,g,c,u,0,0,v,P)}return t.width=v,t.height=P,t}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t="undefined"!=typeof Blob&&(Blob.prototype.slice||Blob.prototype.webkitSlice||Blob.prototype.mozSlice);e.blobSlice=t&&function(){return(this.slice||this.webkitSlice||this.mozSlice).apply(this,arguments)},e.metaDataParsers={jpeg:{65505:[]}},e.parseMetaData=function(t,i,a,o){a=a||{},o=o||{};var n=this,r=a.maxMetaDataSize||262144;!!("undefined"!=typeof DataView&&t&&t.size>=12&&"image/jpeg"===t.type&&e.blobSlice)&&e.readFile(e.blobSlice.call(t,0,r),function(t){if(t.target.error)return console.log(t.target.error),void i(o);var r,s,l,d,c=t.target.result,u=new DataView(c),f=2,g=u.byteLength-4,h=f;if(65496===u.getUint16(0)){for(;f=65504&&r<=65519||65534===r);){if(s=u.getUint16(f+2)+2,f+s>u.byteLength){console.log("Invalid meta data: Invalid segment size.");break}if(l=e.metaDataParsers.jpeg[r])for(d=0;d6&&(c.slice?o.imageHead=c.slice(0,h):o.imageHead=new Uint8Array(c).subarray(0,h))}else console.log("Invalid JPEG file: Missing JPEG marker.");i(o)},"readAsArrayBuffer")||i(o)},e.hasMetaOption=function(e){return e&&e.meta};var i=e.transform;e.transform=function(t,a,o,n,r){e.hasMetaOption(a)?e.parseMetaData(n,function(r){i.call(e,t,a,o,n,r)},a,r):i.apply(e,arguments)}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";"undefined"!=typeof fetch&&"undefined"!=typeof Request&&(e.fetchBlob=function(t,i,a){if(e.hasMetaOption(a))return fetch(new Request(t,a)).then(function(e){return e.blob()}).then(i).catch(function(e){console.log(e),i()});i()})}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";e.ExifMap=function(){return this},e.ExifMap.prototype.map={Orientation:274},e.ExifMap.prototype.get=function(e){return this[e]||this[this.map[e]]},e.getExifThumbnail=function(t,i,a){if(a&&!(i+a>t.byteLength))return e.createObjectURL(new Blob([t.buffer.slice(i,i+a)]));console.log("Invalid Exif data: Invalid thumbnail data.")},e.exifTagTypes={1:{getValue:function(e,t){return e.getUint8(t)},size:1},2:{getValue:function(e,t){return String.fromCharCode(e.getUint8(t))},size:1,ascii:!0},3:{getValue:function(e,t,i){return e.getUint16(t,i)},size:2},4:{getValue:function(e,t,i){return e.getUint32(t,i)},size:4},5:{getValue:function(e,t,i){return e.getUint32(t,i)/e.getUint32(t+4,i)},size:8},9:{getValue:function(e,t,i){return e.getInt32(t,i)},size:4},10:{getValue:function(e,t,i){return e.getInt32(t,i)/e.getInt32(t+4,i)},size:8}},e.exifTagTypes[7]=e.exifTagTypes[1],e.getExifValue=function(t,i,a,o,n,r){var s,l,d,c,u,f,g=e.exifTagTypes[o];if(g){if(s=g.size*n,!((l=s>4?i+t.getUint32(a+8,r):a+8)+s>t.byteLength)){if(1===n)return g.getValue(t,l,r);for(d=[],c=0;ce.byteLength)console.log("Invalid Exif data: Invalid directory offset.");else{if(n=e.getUint16(i,a),!((r=i+2+12*n)+4>e.byteLength)){for(s=0;st.byteLength)console.log("Invalid Exif data: Invalid segment size.");else 
if(0===t.getUint16(i+8)){switch(t.getUint16(d)){case 18761:r=!0;break;case 19789:r=!1;break;default:return void console.log("Invalid Exif data: Invalid byte alignment marker.")}42===t.getUint16(d+2,r)?(s=t.getUint32(d+4,r),o.exif=new e.ExifMap,(s=e.parseExifTags(t,d,d+s,r,o))&&!n.disableExifThumbnail&&(l={exif:{}},s=e.parseExifTags(t,d,d+s,r,l),l.exif[513]&&(o.exif.Thumbnail=e.getExifThumbnail(t,d+l.exif[513],l.exif[514]))),o.exif[34665]&&!n.disableExifSub&&e.parseExifTags(t,d,d+o.exif[34665],r,o),o.exif[34853]&&!n.disableExifGps&&e.parseExifTags(t,d,d+o.exif[34853],r,o)):console.log("Invalid Exif data: Missing TIFF marker.")}else console.log("Invalid Exif data: Missing byte alignment offset.")}},e.metaDataParsers.jpeg[65505].push(e.parseExifData)}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-exif"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-exif")):e(window.loadImage)}(function(e){"use strict";e.ExifMap.prototype.tags={256:"ImageWidth",257:"ImageHeight",34665:"ExifIFDPointer",34853:"GPSInfoIFDPointer",40965:"InteroperabilityIFDPointer",258:"BitsPerSample",259:"Compression",262:"PhotometricInterpretation",274:"Orientation",277:"SamplesPerPixel",284:"PlanarConfiguration",530:"YCbCrSubSampling",531:"YCbCrPositioning",282:"XResolution",283:"YResolution",296:"ResolutionUnit",273:"StripOffsets",278:"RowsPerStrip",279:"StripByteCounts",513:"JPEGInterchangeFormat",514:"JPEGInterchangeFormatLength",301:"TransferFunction",318:"WhitePoint",319:"PrimaryChromaticities",529:"YCbCrCoefficients",532:"ReferenceBlackWhite",306:"DateTime",270:"ImageDescription",271:"Make",272:"Model",305:"Software",315:"Artist",33432:"Copyright",36864:"ExifVersion",40960:"FlashpixVersion",40961:"ColorSpace",40962:"PixelXDimension",40963:"PixelYDimension",42240:"Gamma",37121:"ComponentsConfiguration",37122:"CompressedBitsPerPixel",37500:"MakerNote",37510:"UserComment",40964:"RelatedSoundFile",36867:"DateTimeOriginal",36868:"DateTimeDigitized",37520:"SubSecTime",37521:"SubSecTimeOriginal",37522:"SubSecTimeDigitized",33434:"ExposureTime",33437:"FNumber",34850:"ExposureProgram",34852:"SpectralSensitivity",34855:"PhotographicSensitivity",34856:"OECF",34864:"SensitivityType",34865:"StandardOutputSensitivity",34866:"RecommendedExposureIndex",34867:"ISOSpeed",34868:"ISOSpeedLatitudeyyy",34869:"ISOSpeedLatitudezzz",37377:"ShutterSpeedValue",37378:"ApertureValue",37379:"BrightnessValue",37380:"ExposureBias",37381:"MaxApertureValue",37382:"SubjectDistance",37383:"MeteringMode",37384:"LightSource",37385:"Flash",37396:"SubjectArea",37386:"FocalLength",41483:"FlashEnergy",41484:"SpatialFrequencyResponse",41486:"FocalPlaneXResolution",41487:"FocalPlaneYResolution",41488:"FocalPlaneResolutionUnit",41492:"SubjectLocation",41493:"ExposureIndex",41495:"SensingMethod",41728:"FileSource",41729:"SceneType",41730:"CFAPattern",41985:"CustomRendered",41986:"ExposureMode",41987:"WhiteBalance",41988:"DigitalZoomRatio",41989:"FocalLengthIn35mmFilm",41990:"SceneCaptureType",41991:"GainControl",41992:"Contrast",41993:"Saturation",41994:"Sharpness",41995:"DeviceSettingDescription",41996:"SubjectDistanceRange",42016:"ImageUniqueID",42032:"CameraOwnerName",42033:"BodySerialNumber",42034:"LensSpecification",42035:"LensMake",42036:"LensModel",42037:"LensSerialNumber",0:"GPSVersionID",1:"GPSLatitudeRef",2:"GPSLatitude",3:"GPSLongitudeRef",4:"GPSLongitude",5:"GPSAltitudeRef",6:"GPSAltitude",7:"GPSTimeStamp",8:"GPSSatellites",9:"GPSStatus",10:"GPSMeasureMode",
11:"GPSDOP",12:"GPSSpeedRef",13:"GPSSpeed",14:"GPSTrackRef",15:"GPSTrack",16:"GPSImgDirectionRef",17:"GPSImgDirection",18:"GPSMapDatum",19:"GPSDestLatitudeRef",20:"GPSDestLatitude",21:"GPSDestLongitudeRef",22:"GPSDestLongitude",23:"GPSDestBearingRef",24:"GPSDestBearing",25:"GPSDestDistanceRef",26:"GPSDestDistance",27:"GPSProcessingMethod",28:"GPSAreaInformation",29:"GPSDateStamp",30:"GPSDifferential",31:"GPSHPositioningError"},e.ExifMap.prototype.stringValues={ExposureProgram:{0:"Undefined",1:"Manual",2:"Normal program",3:"Aperture priority",4:"Shutter priority",5:"Creative program",6:"Action program",7:"Portrait mode",8:"Landscape mode"},MeteringMode:{0:"Unknown",1:"Average",2:"CenterWeightedAverage",3:"Spot",4:"MultiSpot",5:"Pattern",6:"Partial",255:"Other"},LightSource:{0:"Unknown",1:"Daylight",2:"Fluorescent",3:"Tungsten (incandescent light)",4:"Flash",9:"Fine weather",10:"Cloudy weather",11:"Shade",12:"Daylight fluorescent (D 5700 - 7100K)",13:"Day white fluorescent (N 4600 - 5400K)",14:"Cool white fluorescent (W 3900 - 4500K)",15:"White fluorescent (WW 3200 - 3700K)",17:"Standard light A",18:"Standard light B",19:"Standard light C",20:"D55",21:"D65",22:"D75",23:"D50",24:"ISO studio tungsten",255:"Other"},Flash:{0:"Flash did not fire",1:"Flash fired",5:"Strobe return light not detected",7:"Strobe return light detected",9:"Flash fired, compulsory flash mode",13:"Flash fired, compulsory flash mode, return light not detected",15:"Flash fired, compulsory flash mode, return light detected",16:"Flash did not fire, compulsory flash mode",24:"Flash did not fire, auto mode",25:"Flash fired, auto mode",29:"Flash fired, auto mode, return light not detected",31:"Flash fired, auto mode, return light detected",32:"No flash function",65:"Flash fired, red-eye reduction mode",69:"Flash fired, red-eye reduction mode, return light not detected",71:"Flash fired, red-eye reduction mode, return light detected",73:"Flash fired, compulsory flash mode, red-eye reduction mode",77:"Flash fired, compulsory flash mode, red-eye reduction mode, return light not detected",79:"Flash fired, compulsory flash mode, red-eye reduction mode, return light detected",89:"Flash fired, auto mode, red-eye reduction mode",93:"Flash fired, auto mode, return light not detected, red-eye reduction mode",95:"Flash fired, auto mode, return light detected, red-eye reduction mode"},SensingMethod:{1:"Undefined",2:"One-chip color area sensor",3:"Two-chip color area sensor",4:"Three-chip color area sensor",5:"Color sequential area sensor",7:"Trilinear sensor",8:"Color sequential linear sensor"},SceneCaptureType:{0:"Standard",1:"Landscape",2:"Portrait",3:"Night scene"},SceneType:{1:"Directly photographed"},CustomRendered:{0:"Normal process",1:"Custom process"},WhiteBalance:{0:"Auto white balance",1:"Manual white balance"},GainControl:{0:"None",1:"Low gain up",2:"High gain up",3:"Low gain down",4:"High gain down"},Contrast:{0:"Normal",1:"Soft",2:"Hard"},Saturation:{0:"Normal",1:"Low saturation",2:"High saturation"},Sharpness:{0:"Normal",1:"Soft",2:"Hard"},SubjectDistanceRange:{0:"Unknown",1:"Macro",2:"Close view",3:"Distant view"},FileSource:{3:"DSC"},ComponentsConfiguration:{0:"",1:"Y",2:"Cb",3:"Cr",4:"R",5:"G",6:"B"},Orientation:{1:"top-left",2:"top-right",3:"bottom-right",4:"bottom-left",5:"left-top",6:"right-top",7:"right-bottom",8:"left-bottom"}},e.ExifMap.prototype.getText=function(e){var 
t=this.get(e);switch(e){case"LightSource":case"Flash":case"MeteringMode":case"ExposureProgram":case"SensingMethod":case"SceneCaptureType":case"SceneType":case"CustomRendered":case"WhiteBalance":case"GainControl":case"Contrast":case"Saturation":case"Sharpness":case"SubjectDistanceRange":case"FileSource":case"Orientation":return this.stringValues[e][t];case"ExifVersion":case"FlashpixVersion":if(!t)return;return String.fromCharCode(t[0],t[1],t[2],t[3]);case"ComponentsConfiguration":if(!t)return;return this.stringValues[e][t[0]]+this.stringValues[e][t[1]]+this.stringValues[e][t[2]]+this.stringValues[e][t[3]];case"GPSVersionID":if(!t)return;return t[0]+"."+t[1]+"."+t[2]+"."+t[3]}return String(t)},function(e){var t,i=e.tags,a=e.map;for(t in i)i.hasOwnProperty(t)&&(a[i[t]]=t)}(e.ExifMap.prototype),e.ExifMap.prototype.getAll=function(){var e,t,i={};for(e in this)this.hasOwnProperty(e)&&(t=this.tags[e])&&(i[t]=this.getText(t));return i}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-scale","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-scale"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";var t=e.hasCanvasOption,i=e.hasMetaOption,a=e.transformCoordinates,o=e.getTransformedOptions;e.hasCanvasOption=function(i){return!!i.orientation||t.call(e,i)},e.hasMetaOption=function(t){return t&&!0===t.orientation||i.call(e,t)},e.transformCoordinates=function(t,i){a.call(e,t,i);var o=t.getContext("2d"),n=t.width,r=t.height,s=t.style.width,l=t.style.height,d=i.orientation;if(d&&!(d>8))switch(d>4&&(t.width=r,t.height=n,t.style.width=l,t.style.height=s),d){case 2:o.translate(n,0),o.scale(-1,1);break;case 3:o.translate(n,r),o.rotate(Math.PI);break;case 4:o.translate(0,r),o.scale(1,-1);break;case 5:o.rotate(.5*Math.PI),o.scale(1,-1);break;case 6:o.rotate(.5*Math.PI),o.translate(0,-r);break;case 7:o.rotate(.5*Math.PI),o.translate(n,-r),o.scale(-1,1);break;case 8:o.rotate(-.5*Math.PI),o.translate(-n,0)}},e.getTransformedOptions=function(t,i,a){var n,r,s=o.call(e,t,i),l=s.orientation;if(!0===l&&a&&a.exif&&(l=a.exif.get("Orientation")),!l||l>8||1===l)return s;n={};for(r in s)s.hasOwnProperty(r)&&(n[r]=s[r]);switch(n.orientation=l,l){case 2:n.left=s.right,n.right=s.left;break;case 3:n.left=s.right,n.top=s.bottom,n.right=s.left,n.bottom=s.top;break;case 4:n.top=s.bottom,n.bottom=s.top;break;case 5:n.left=s.top,n.top=s.left,n.right=s.bottom,n.bottom=s.right;break;case 6:n.left=s.top,n.top=s.right,n.right=s.bottom,n.bottom=s.left;break;case 7:n.left=s.bottom,n.top=s.right,n.right=s.top,n.bottom=s.left;break;case 8:n.left=s.bottom,n.top=s.left,n.right=s.top,n.bottom=s.right}return n.orientation>4&&(n.maxWidth=s.maxHeight,n.maxHeight=s.maxWidth,n.minWidth=s.minHeight,n.minHeight=s.minWidth,n.sourceWidth=s.sourceHeight,n.sourceHeight=s.sourceWidth),n}});
2 | //# sourceMappingURL=load-image.all.min.js.map
3 |
--------------------------------------------------------------------------------
/demo/web/load-image.all.min.js.map:
--------------------------------------------------------------------------------
1 | {"version":3,"sources":["load-image.js","load-image-scale.js","load-image-meta.js","load-image-fetch.js","load-image-exif.js","load-image-exif-map.js","load-image-orientation.js"],"names":["$","loadImage","file","callback","options","url","img","document","createElement","onerror","event","onload","fetchBlob","blob","createObjectURL","crossOrigin","src","isInstanceOf","_objectURL","readFile","e","target","result","revokeHelper","noRevoke","revokeObjectURL","urlAPI","URL","webkitURL","type","obj","Object","prototype","toString","call","transform","data","method","FileReader","fileReader","define","amd","module","exports","window","this","factory","require","originalTransform","scale","transformCoordinates","getTransformedOptions","newOptions","i","width","height","aspectRatio","hasOwnProperty","crop","naturalWidth","naturalHeight","maxWidth","maxHeight","renderImageToCanvas","canvas","sourceX","sourceY","sourceWidth","sourceHeight","destX","destY","destWidth","destHeight","getContext","drawImage","hasCanvasOption","scaleUp","Math","max","minWidth","minHeight","scaleDown","min","pixelRatio","downsamplingRatio","tmp","useCanvas","left","top","undefined","right","bottom","contain","cover","style","hasblobSlice","Blob","slice","webkitSlice","mozSlice","blobSlice","apply","arguments","metaDataParsers","jpeg","65505","parseMetaData","that","maxMetaDataSize","DataView","size","error","console","log","markerBytes","markerLength","parsers","buffer","dataView","offset","maxOffset","byteLength","headLength","getUint16","length","disableImageHead","imageHead","Uint8Array","subarray","hasMetaOption","meta","fetch","Request","then","response","catch","err","ExifMap","map","Orientation","get","id","getExifThumbnail","exifTagTypes","1","getValue","dataOffset","getUint8","2","String","fromCharCode","ascii","3","littleEndian","4","getUint32","5","9","getInt32","10","getExifValue","tiffOffset","tagSize","values","str","c","tagType","parseExifTag","tag","exif","parseExifTags","dirOffset","tagsNumber","dirEndOffset","parseExifData","disableExif","thumbnailData","disableExifThumbnail","Thumbnail","disableExifSub","disableExifGps","push","tags","256","257","34665","34853","40965","258","259","262","274","277","284","530","531","282","283","296","273","278","279","513","514","301","318","319","529","532","306","270","271","272","305","315","33432","36864","40960","40961","40962","40963","42240","37121","37122","37500","37510","40964","36867","36868","37520","37521","37522","33434","33437","34850","34852","34855","34856","34864","34865","34866","34867","34868","34869","37377","37378","37379","37380","37381","37382","37383","37384","37385","37396","37386","41483","41484","41486","41487","41488","41492","41493","41495","41728","41729","41730","41985","41986","41987","41988","41989","41990","41991","41992","41993","41994","41995","41996","42016","42032","42033","42034","42035","42036","42037","0","6","7","8","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","stringValues","ExposureProgram","MeteringMode","255","LightSource","Flash","32","65","69","71","73","77","79","89","93","95","SensingMethod","SceneCaptureType","SceneType","CustomRendered","WhiteBalance","GainControl","Contrast","Saturation","Sharpness","SubjectDistanceRange","FileSource","ComponentsConfiguration","getText","value","exifMapPrototype","prop","getAll","originalHasCanvasOption","originalHasMetaOption","originalTransformCoordinates","originalGetTransformedOptions","orientation","ctx","styleWidt
h","styleHeight","translate","rotate","PI","opts"],"mappings":"CAaC,SAAWA,GACV,aAKA,SAASC,EAAWC,EAAMC,EAAUC,GAClC,IACIC,EADAC,EAAMC,SAASC,cAAc,OAQjC,OANAF,EAAIG,QAAU,SAAUC,GACtB,OAAOT,EAAUQ,QAAQH,EAAKI,EAAOR,EAAMC,EAAUC,IAEvDE,EAAIK,OAAS,SAAUD,GACrB,OAAOT,EAAUU,OAAOL,EAAKI,EAAOR,EAAMC,EAAUC,IAElC,iBAATF,GACTD,EAAUW,UACRV,EACA,SAAUW,GACJA,GACFX,EAAOW,EACPR,EAAMJ,EAAUa,gBAAgBZ,KAEhCG,EAAMH,EACFE,GAAWA,EAAQW,cACrBT,EAAIS,YAAcX,EAAQW,cAG9BT,EAAIU,IAAMX,GAEZD,GAEKE,GAEPL,EAAUgB,aAAa,OAAQf,IAG/BD,EAAUgB,aAAa,OAAQf,IAE/BG,EAAMC,EAAIY,WAAajB,EAAUa,gBAAgBZ,KAE/CI,EAAIU,IAAMX,EACHC,GAEFL,EAAUkB,SAASjB,EAAM,SAAUkB,GACxC,IAAIC,EAASD,EAAEC,OACXA,GAAUA,EAAOC,OACnBhB,EAAIU,IAAMK,EAAOC,OACRnB,GACTA,EAASiB,UAhBR,EA4BT,SAASG,EAAcjB,EAAKF,IACtBE,EAAIY,YAAgBd,GAAWA,EAAQoB,WACzCvB,EAAUwB,gBAAgBnB,EAAIY,mBACvBZ,EAAIY,YARf,IAAIQ,EACD1B,EAAEc,iBAAmBd,GACrBA,EAAE2B,KAAOA,IAAIF,iBAAmBE,KAChC3B,EAAE4B,WAAaA,UAYlB3B,EAAUW,UAAY,SAAUP,EAAKF,EAAUC,GAC7CD,KAGFF,EAAUgB,aAAe,SAAUY,EAAMC,GAEvC,OAAOC,OAAOC,UAAUC,SAASC,KAAKJ,KAAS,WAAaD,EAAO,KAGrE5B,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GAC5DjC,EAASG,EAAK8B,IAGhBnC,EAAUQ,QAAU,SAAUH,EAAKI,EAAOR,EAAMC,EAAUC,GACxDmB,EAAajB,EAAKF,GACdD,GACFA,EAAS+B,KAAK5B,EAAKI,IAIvBT,EAAUU,OAAS,SAAUL,EAAKI,EAAOR,EAAMC,EAAUC,GACvDmB,EAAajB,EAAKF,GACdD,GACFF,EAAUkC,UAAU7B,EAAKF,EAASD,EAAUD,OAIhDD,EAAUa,gBAAkB,SAAUZ,GACpC,QAAOwB,GAASA,EAAOZ,gBAAgBZ,IAGzCD,EAAUwB,gBAAkB,SAAUpB,GACpC,QAAOqB,GAASA,EAAOD,gBAAgBpB,IAMzCJ,EAAUkB,SAAW,SAAUjB,EAAMC,EAAUkC,GAC7C,GAAIrC,EAAEsC,WAAY,CAChB,IAAIC,EAAa,IAAID,WAGrB,GAFAC,EAAW5B,OAAS4B,EAAW9B,QAAUN,EACzCkC,EAASA,GAAU,gBACfE,EAAWF,GAEb,OADAE,EAAWF,GAAQnC,GACZqC,EAGX,OAAO,GAGa,mBAAXC,QAAyBA,OAAOC,IACzCD,OAAO,WACL,OAAOvC,IAEkB,iBAAXyC,QAAuBA,OAAOC,QAC9CD,OAAOC,QAAU1C,EAEjBD,EAAEC,UAAYA,EAjIjB,CAmIqB,oBAAX2C,QAA0BA,QAAWC,MCnI/C,SAAWC,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,gBAAiBM,GAEzBA,EAD2B,iBAAXJ,QAAuBA,OAAOC,QACtCI,QAAQ,gBAGRH,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEA,IAAI+C,EAAoB/C,EAAUkC,UAElClC,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GAC5DY,EAAkBd,KAChBjC,EACAA,EAAUgD,MAAM3C,EAAKF,EAASgC,GAC9BhC,EACAD,EACAD,EACAkC,IAOJnC,EAAUiD,qBAAuB,aAKjCjD,EAAUkD,sBAAwB,SAAU7C,EAAKF,GAC/C,IACIgD,EACAC,EACAC,EACAC,EAJAC,EAAcpD,EAAQoD,YAK1B,IAAKA,EACH,OAAOpD,EAETgD,KACA,IAAKC,KAAKjD,EACJA,EAAQqD,eAAeJ,KACzBD,EAAWC,GAAKjD,EAAQiD,IAa5B,OAVAD,EAAWM,MAAO,EAClBJ,EAAQhD,EAAIqD,cAAgBrD,EAAIgD,MAChCC,EAASjD,EAAIsD,eAAiBtD,EAAIiD,OAC9BD,EAAQC,EAASC,GACnBJ,EAAWS,SAAWN,EAASC,EAC/BJ,EAAWU,UAAYP,IAEvBH,EAAWS,SAAWP,EACtBF,EAAWU,UAAYR,EAAQE,GAE1BJ,GAITnD,EAAU8D,oBAAsB,SAC9BC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,GAeA,OAbAR,EACGS,WAAW,MACXC,UACCpE,EACA2D,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,GAEGR,GAIT/D,EAAU0E,gBAAkB,SAAUvE,GACpC,OAAOA,EAAQ4D,QAAU5D,EAAQsD,QAAUtD,EAAQoD,aAQrDvD,EAAUgD,MAAQ,SAAU3C,EAAKF,EAASgC,GAqBxC,SAASwC,IACP,IAAI3B,EAAQ4B,KAAKC,KACdC,GAAYR,GAAaA,GACzBS,GAAaR,GAAcA,GAE1BvB,EAAQ,IACVsB,GAAatB,EACbuB,GAAcvB,GAGlB,SAASgC,IACP,IAAIhC,EAAQ4B,KAAKK,KACdrB,GAAYU,GAAaA,GACzBT,GAAaU,GAAcA,GAE1BvB,EAAQ,IACVsB,GAAatB,EACbuB,GAAcvB,GArClB7C,EAAUA,MACV,IAQIyD,EACAC,EACAiB,EACAC,EACAb,EACAC,EACAH,EACAC,EACAiB,EACAC,EACAC,EAlBArB,EAASzD,SAASC,cAAc,UAChC8E,EACFhF,EAAImE,YACHxE,EAAU0E,gBAAgBvE,IAAY4D,EAAOS,WAC5CnB,EAAQhD,EAAIqD,cAAgBrD,EAAIgD,MAChCC,EAASjD,EAAIsD,eAAiBtD,EAAIiD,OAClCgB,EAAYjB,EACZkB,EAAajB,EAuFjB,GAvDI+B,IAEFrB,GADA7D,EAAUH,EAAUkD,sBAAsB7C,EAAKF,EAASgC,IACtCmD,MAAQ,EAC1BrB,EAAU9D,EAAQoF,KAAO,EACrBpF,EAAQ+D,aACVA,EAAc/D,EAAQ+D,iBACAsB,IAAlBrF,EAAQsF,YAAwCD,IAAjBrF,EAAQmF,OACzCtB,EAAUX,EAAQa,EAAc/D,EAAQsF,QAG1CvB,EAAcb,EAAQW,GAAW7D,EAAQsF,OAAS,GAEhD
tF,EAAQgE,cACVA,EAAehE,EAAQgE,kBACAqB,IAAnBrF,EAAQuF,aAAwCF,IAAhBrF,EAAQoF,MAC1CtB,EAAUX,EAASa,EAAehE,EAAQuF,SAG5CvB,EAAeb,EAASW,GAAW9D,EAAQuF,QAAU,GAEvDpB,EAAYJ,EACZK,EAAaJ,GAEfP,EAAWzD,EAAQyD,SACnBC,EAAY1D,EAAQ0D,UACpBiB,EAAW3E,EAAQ2E,SACnBC,EAAY5E,EAAQ4E,UAChBM,GAAazB,GAAYC,GAAa1D,EAAQsD,MAChDa,EAAYV,EACZW,EAAaV,GACbuB,EAAMlB,EAAcC,EAAeP,EAAWC,GACpC,GACRM,EAAeN,EAAYK,EAAcN,OACrB4B,IAAhBrF,EAAQoF,UAAwCC,IAAnBrF,EAAQuF,SACvCzB,GAAWX,EAASa,GAAgB,IAE7BiB,EAAM,IACflB,EAAcN,EAAWO,EAAeN,OACnB2B,IAAjBrF,EAAQmF,WAAwCE,IAAlBrF,EAAQsF,QACxCzB,GAAWX,EAAQa,GAAe,OAIlC/D,EAAQwF,SAAWxF,EAAQyF,SAC7Bd,EAAWlB,EAAWA,GAAYkB,EAClCC,EAAYlB,EAAYA,GAAakB,GAEnC5E,EAAQyF,OACVZ,IACAL,MAEAA,IACAK,MAGAK,EAAW,CAUb,IATAH,EAAa/E,EAAQ+E,YACJ,IACfnB,EAAO8B,MAAMxC,MAAQiB,EAAY,KACjCP,EAAO8B,MAAMvC,OAASiB,EAAa,KACnCD,GAAaY,EACbX,GAAcW,EACdnB,EAAOS,WAAW,MAAMxB,MAAMkC,EAAYA,KAE5CC,EAAoBhF,EAAQgF,mBAEN,GACpBA,EAAoB,GACpBb,EAAYJ,GACZK,EAAaJ,EAEb,KAAOD,EAAciB,EAAoBb,GACvCP,EAAOV,MAAQa,EAAciB,EAC7BpB,EAAOT,OAASa,EAAegB,EAC/BnF,EAAU8D,oBACRC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACA,EACA,EACAJ,EAAOV,MACPU,EAAOT,QAETU,EAAU,EACVC,EAAU,EACVC,EAAcH,EAAOV,MACrBc,EAAeJ,EAAOT,QACtBjD,EAAMC,SAASC,cAAc,WACzB8C,MAAQa,EACZ7D,EAAIiD,OAASa,EACbnE,EAAU8D,oBACRzD,EACA0D,EACA,EACA,EACAG,EACAC,EACA,EACA,EACAD,EACAC,GAON,OAHAJ,EAAOV,MAAQiB,EACfP,EAAOT,OAASiB,EAChBvE,EAAUiD,qBAAqBc,EAAQ5D,GAChCH,EAAU8D,oBACfC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACA,EACA,EACAG,EACAC,GAKJ,OAFAlE,EAAIgD,MAAQiB,EACZjE,EAAIiD,OAASiB,EACNlE,KCxQV,SAAWwC,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,gBAAiBM,GAEzBA,EAD2B,iBAAXJ,QAAuBA,OAAOC,QACtCI,QAAQ,gBAGRH,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEA,IAAI8F,EACc,oBAATC,OACNA,KAAKhE,UAAUiE,OACdD,KAAKhE,UAAUkE,aACfF,KAAKhE,UAAUmE,UAEnBlG,EAAUmG,UACRL,GACA,WAEE,OADYlD,KAAKoD,OAASpD,KAAKqD,aAAerD,KAAKsD,UACtCE,MAAMxD,KAAMyD,YAG7BrG,EAAUsG,iBACRC,MACEC,WAUJxG,EAAUyG,cAAgB,SAAUxG,EAAMC,EAAUC,EAASgC,GAC3DhC,EAAUA,MACVgC,EAAOA,MACP,IAAIuE,EAAO9D,KAEP+D,EAAkBxG,EAAQwG,iBAAmB,UAE3B,oBAAbC,UACP3G,GACAA,EAAK4G,MAAQ,IACC,eAAd5G,EAAK2B,MACL5B,EAAUmG,YAITnG,EAAUkB,SACTlB,EAAUmG,UAAUlE,KAAKhC,EAAM,EAAG0G,GAClC,SAAUxF,GACR,GAAIA,EAAEC,OAAO0F,MAIX,OAFAC,QAAQC,IAAI7F,EAAEC,OAAO0F,YACrB5G,EAASiC,GAOX,IAKI8E,EACAC,EACAC,EACA/D,EARAgE,EAASjG,EAAEC,OAAOC,OAClBgG,EAAW,IAAIT,SAASQ,GACxBE,EAAS,EACTC,EAAYF,EAASG,WAAa,EAClCC,EAAaH,EAMjB,GAA8B,QAA1BD,EAASK,UAAU,GAAe,CACpC,KAAOJ,EAASC,KACdN,EAAcI,EAASK,UAAUJ,KAKf,OAAUL,GAAe,OACzB,QAAhBA,IAPuB,CAcvB,GADAC,EAAeG,EAASK,UAAUJ,EAAS,GAAK,EAC5CA,EAASJ,EAAeG,EAASG,WAAY,CAC/CT,QAAQC,IAAI,4CACZ,MAGF,GADAG,EAAUnH,EAAUsG,gBAAgBC,KAAKU,GAEvC,IAAK7D,EAAI,EAAGA,EAAI+D,EAAQQ,OAAQvE,GAAK,EACnC+D,EAAQ/D,GAAGnB,KACTyE,EACAW,EACAC,EACAJ,EACA/E,EACAhC,GAKNsH,EADAH,GAAUJ,GAUT/G,EAAQyH,kBAAoBH,EAAa,IACxCL,EAAOpB,MACT7D,EAAK0F,UAAYT,EAAOpB,MAAM,EAAGyB,GAIjCtF,EAAK0F,UAAY,IAAIC,WAAWV,GAAQW,SAAS,EAAGN,SAIxDV,QAAQC,IAAI,2CAEd9G,EAASiC,IAEX,sBAGFjC,EAASiC,IAKbnC,EAAUgI,cAAgB,SAAU7H,GAClC,OAAOA,GAAWA,EAAQ8H,MAG5B,IAAIlF,EAAoB/C,EAAUkC,UAClClC,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GACxDnC,EAAUgI,cAAc7H,GAC1BH,EAAUyG,cACRxG,EACA,SAAUkC,GACRY,EAAkBd,KAAKjC,EAAWK,EAAKF,EAASD,EAAUD,EAAMkC,IAElEhC,EACAgC,GAGFY,EAAkBqD,MAAMpG,EAAWqG,cCjKxC,SAAWxD,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBAAiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEqB,oBAAVkI,OAA4C,oBAAZC,UACzCnI,EAAUW,UAAY,SAAUP,EAAKF,EAAUC,GAC7C,GAAIH,EAAUgI,cAAc7H,GAC1B,OAAO+H,MAAM,IAAIC,QAAQ/H,EAAKD,IAC3BiI,KAAK,SAAUC,GACd,OAAOA,EAASzH,SAEjBwH,KAAKlI,GACLoI,MAAM,SAAUC,GACfxB,QAAQC,IAAIuB,GACZrI,MAGJA,QC3BP,SAAW2C,GACV,aACsB,mB
AAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBAAiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEAA,EAAUwI,QAAU,WAClB,OAAO5F,MAGT5C,EAAUwI,QAAQzG,UAAU0G,KAC1BC,YAAa,KAGf1I,EAAUwI,QAAQzG,UAAU4G,IAAM,SAAUC,GAC1C,OAAOhG,KAAKgG,IAAOhG,KAAKA,KAAK6F,IAAIG,KAGnC5I,EAAU6I,iBAAmB,SAAUxB,EAAUC,EAAQK,GACvD,GAAKA,KAAUL,EAASK,EAASN,EAASG,YAI1C,OAAOxH,EAAUa,gBACf,IAAIkF,MAAMsB,EAASD,OAAOpB,MAAMsB,EAAQA,EAASK,MAJjDZ,QAAQC,IAAI,+CAQhBhH,EAAU8I,cAERC,GACEC,SAAU,SAAU3B,EAAU4B,GAC5B,OAAO5B,EAAS6B,SAASD,IAE3BpC,KAAM,GAGRsC,GACEH,SAAU,SAAU3B,EAAU4B,GAC5B,OAAOG,OAAOC,aAAahC,EAAS6B,SAASD,KAE/CpC,KAAM,EACNyC,OAAO,GAGTC,GACEP,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASK,UAAUuB,EAAYO,IAExC3C,KAAM,GAGR4C,GACET,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASqC,UAAUT,EAAYO,IAExC3C,KAAM,GAGR8C,GACEX,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OACEnC,EAASqC,UAAUT,EAAYO,GAC/BnC,EAASqC,UAAUT,EAAa,EAAGO,IAGvC3C,KAAM,GAGR+C,GACEZ,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASwC,SAASZ,EAAYO,IAEvC3C,KAAM,GAGRiD,IACEd,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OACEnC,EAASwC,SAASZ,EAAYO,GAC9BnC,EAASwC,SAASZ,EAAa,EAAGO,IAGtC3C,KAAM,IAIV7G,EAAU8I,aAAa,GAAK9I,EAAU8I,aAAa,GAEnD9I,EAAU+J,aAAe,SACvB1C,EACA2C,EACA1C,EACA1F,EACA+F,EACA6B,GAEA,IACIS,EACAhB,EACAiB,EACA9G,EACA+G,EACAC,EANAC,EAAUrK,EAAU8I,aAAalH,GAOrC,GAAKyI,EAAL,CAWA,GAPAJ,EAAUI,EAAQxD,KAAOc,KAGzBsB,EACEgB,EAAU,EACND,EAAa3C,EAASqC,UAAUpC,EAAS,EAAGkC,GAC5ClC,EAAS,GACE2C,EAAU5C,EAASG,YAApC,CAIA,GAAe,IAAXG,EACF,OAAO0C,EAAQrB,SAAS3B,EAAU4B,EAAYO,GAGhD,IADAU,KACK9G,EAAI,EAAGA,EAAIuE,EAAQvE,GAAK,EAC3B8G,EAAO9G,GAAKiH,EAAQrB,SAClB3B,EACA4B,EAAa7F,EAAIiH,EAAQxD,KACzB2C,GAGJ,GAAIa,EAAQf,MAAO,CAGjB,IAFAa,EAAM,GAED/G,EAAI,EAAGA,EAAI8G,EAAOvC,QAGX,QAFVyC,EAAIF,EAAO9G,IADkBA,GAAK,EAMlC+G,GAAOC,EAET,OAAOD,EAET,OAAOD,EA3BLnD,QAAQC,IAAI,gDAXZD,QAAQC,IAAI,yCAyChBhH,EAAUsK,aAAe,SACvBjD,EACA2C,EACA1C,EACAkC,EACArH,GAEA,IAAIoI,EAAMlD,EAASK,UAAUJ,EAAQkC,GACrCrH,EAAKqI,KAAKD,GAAOvK,EAAU+J,aACzB1C,EACA2C,EACA1C,EACAD,EAASK,UAAUJ,EAAS,EAAGkC,GAC/BnC,EAASqC,UAAUpC,EAAS,EAAGkC,GAC/BA,IAIJxJ,EAAUyK,cAAgB,SACxBpD,EACA2C,EACAU,EACAlB,EACArH,GAEA,IAAIwI,EAAYC,EAAcxH,EAC9B,GAAIsH,EAAY,EAAIrD,EAASG,WAC3BT,QAAQC,IAAI,oDADd,CAMA,GAFA2D,EAAatD,EAASK,UAAUgD,EAAWlB,MAC3CoB,EAAeF,EAAY,EAAI,GAAKC,GACjB,EAAItD,EAASG,YAAhC,CAIA,IAAKpE,EAAI,EAAGA,EAAIuH,EAAYvH,GAAK,EAC/BR,KAAK0H,aACHjD,EACA2C,EACAU,EAAY,EAAI,GAAKtH,EACrBoG,EACArH,GAIJ,OAAOkF,EAASqC,UAAUkB,EAAcpB,GAbtCzC,QAAQC,IAAI,gDAgBhBhH,EAAU6K,cAAgB,SAAUxD,EAAUC,EAAQK,EAAQxF,EAAMhC,GAClE,IAAIA,EAAQ2K,YAAZ,CAGA,IACItB,EACAkB,EACAK,EAHAf,EAAa1C,EAAS,GAK1B,GAAuC,aAAnCD,EAASqC,UAAUpC,EAAS,GAIhC,GAAI0C,EAAa,EAAI3C,EAASG,WAC5BT,QAAQC,IAAI,iDAId,GAAuC,IAAnCK,EAASK,UAAUJ,EAAS,GAAhC,CAKA,OAAQD,EAASK,UAAUsC,IACzB,KAAK,MACHR,GAAe,EACf,MACF,KAAK,MACHA,GAAe,EACf,MACF,QAEE,YADAzC,QAAQC,IAAI,qDAIyC,KAArDK,EAASK,UAAUsC,EAAa,EAAGR,IAKvCkB,EAAYrD,EAASqC,UAAUM,EAAa,EAAGR,GAE/CrH,EAAKqI,KAAO,IAAIxK,EAAUwI,SAG1BkC,EAAY1K,EAAUyK,cACpBpD,EACA2C,EACAA,EAAaU,EACblB,EACArH,MAEgBhC,EAAQ6K,uBACxBD,GAAkBP,SAClBE,EAAY1K,EAAUyK,cACpBpD,EACA2C,EACAA,EAAaU,EACblB,EACAuB,GAGEA,EAAcP,KAAK,OACrBrI,EAAKqI,KAAKS,UAAYjL,EAAU6I,iBAC9BxB,EACA2C,EAAae,EAAcP,KAAK,KAChCO,EAAcP,KAAK,QAKrBrI,EAAKqI,KAAK,SAAYrK,EAAQ+K,gBAChClL,EAAUyK,cACRpD,EACA2C,EACAA,EAAa7H,EAAKqI,KAAK,OACvBhB,EACArH,GAIAA,EAAKqI,KAAK,SAAYrK,EAAQgL,gBAChCnL,EAAUyK,cACRpD,EACA2C,EACAA,EAAa7H,EAAKqI,KAAK,OACvBhB,EACArH,IAnDF4E,QAAQC,IAAI,gDAjBZD,QAAQC,IAAI,uDA0EhBhH,EAAUsG,gBAAgBC,KAAK,OAAQ6E,KAAKpL,EAAU6K,iBCrSvD,SAAWhI,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBA
AiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEAA,EAAUwI,QAAQzG,UAAUsJ,MAI1BC,IAAQ,aACRC,IAAQ,cACRC,MAAQ,iBACRC,MAAQ,oBACRC,MAAQ,6BACRC,IAAQ,gBACRC,IAAQ,cACRC,IAAQ,4BACRC,IAAQ,cACRC,IAAQ,kBACRC,IAAQ,sBACRC,IAAQ,mBACRC,IAAQ,mBACRC,IAAQ,cACRC,IAAQ,cACRC,IAAQ,iBACRC,IAAQ,eACRC,IAAQ,eACRC,IAAQ,kBACRC,IAAQ,wBACRC,IAAQ,8BACRC,IAAQ,mBACRC,IAAQ,aACRC,IAAQ,wBACRC,IAAQ,oBACRC,IAAQ,sBACRC,IAAQ,WACRC,IAAQ,mBACRC,IAAQ,OACRC,IAAQ,QACRC,IAAQ,WACRC,IAAQ,SACRC,MAAQ,YAIRC,MAAQ,cACRC,MAAQ,kBACRC,MAAQ,aACRC,MAAQ,kBACRC,MAAQ,kBACRC,MAAQ,QACRC,MAAQ,0BACRC,MAAQ,yBACRC,MAAQ,YACRC,MAAQ,cACRC,MAAQ,mBACRC,MAAQ,mBACRC,MAAQ,oBACRC,MAAQ,aACRC,MAAQ,qBACRC,MAAQ,sBACRC,MAAQ,eACRC,MAAQ,UACRC,MAAQ,kBACRC,MAAQ,sBACRC,MAAQ,0BACRC,MAAQ,OACRC,MAAQ,kBACRC,MAAQ,4BACRC,MAAQ,2BACRC,MAAQ,WACRC,MAAQ,sBACRC,MAAQ,sBACRC,MAAQ,oBACRC,MAAQ,gBACRC,MAAQ,kBACRC,MAAQ,eACRC,MAAQ,mBACRC,MAAQ,kBACRC,MAAQ,eACRC,MAAQ,cACRC,MAAQ,QACRC,MAAQ,cACRC,MAAQ,cACRC,MAAQ,cACRC,MAAQ,2BACRC,MAAQ,wBACRC,MAAQ,wBACRC,MAAQ,2BACRC,MAAQ,kBACRC,MAAQ,gBACRC,MAAQ,gBACRC,MAAQ,aACRC,MAAQ,YACRC,MAAQ,aACRC,MAAQ,iBACRC,MAAQ,eACRC,MAAQ,eACRC,MAAQ,mBACRC,MAAQ,wBACRC,MAAQ,mBACRC,MAAQ,cACRC,MAAQ,WACRC,MAAQ,aACRC,MAAQ,YACRC,MAAQ,2BACRC,MAAQ,uBACRC,MAAQ,gBACRC,MAAQ,kBACRC,MAAQ,mBACRC,MAAQ,oBACRC,MAAQ,WACRC,MAAQ,YACRC,MAAQ,mBAIRC,EAAQ,eACR7I,EAAQ,iBACRI,EAAQ,cACRI,EAAQ,kBACRE,EAAQ,eACRE,EAAQ,iBACRkI,EAAQ,cACRC,EAAQ,eACRC,EAAQ,gBACRnI,EAAQ,YACRE,GAAQ,iBACRkI,GAAQ,SACRC,GAAQ,cACRC,GAAQ,WACRC,GAAQ,cACRC,GAAQ,WACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,cACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,sBACRC,GAAQ,mBACRC,GAAQ,oBACRC,GAAQ,iBACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,sBACRC,GAAQ,qBACRC,GAAQ,eACRC,GAAQ,kBACRC,GAAQ,wBAGVpT,EAAUwI,QAAQzG,UAAUsR,cAC1BC,iBACE1B,EAAG,YACH7I,EAAG,SACHI,EAAG,iBACHI,EAAG,oBACHE,EAAG,mBACHE,EAAG,mBACHkI,EAAG,iBACHC,EAAG,gBACHC,EAAG,kBAELwB,cACE3B,EAAG,UACH7I,EAAG,UACHI,EAAG,wBACHI,EAAG,OACHE,EAAG,YACHE,EAAG,UACHkI,EAAG,UACH2B,IAAK,SAEPC,aACE7B,EAAG,UACH7I,EAAG,WACHI,EAAG,cACHI,EAAG,gCACHE,EAAG,QACHG,EAAG,eACHE,GAAI,iBACJkI,GAAI,QACJC,GAAI,wCACJC,GAAI,yCACJC,GAAI,0CACJC,GAAI,sCACJE,GAAI,mBACJC,GAAI,mBACJC,GAAI,mBACJC,GAAI,MACJC,GAAI,MACJC,GAAI,MACJC,GAAI,MACJC,GAAI,sBACJW,IAAK,SAEPE,OACE9B,EAAQ,qBACR7I,EAAQ,cACRY,EAAQ,mCACRmI,EAAQ,+BACRlI,EAAQ,qCACRsI,GAAQ,gEACRE,GAAQ,4DACRC,GAAQ,4CACRQ,GAAQ,gCACRC,GAAQ,yBACRI,GAAQ,oDACRE,GAAQ,gDACRO,GAAQ,oBACRC,GAAQ,sCACRC,GAAQ,iEACRC,GAAQ,6DACRC,GAAQ,6DACRC,GAAQ,wFACRC,GAAQ,oFACRC,GAAQ,iDACRC,GAAQ,4EACRC,GAAQ,yEAEVC,eACEtL,EAAG,YACHI,EAAG,6BACHI,EAAG,6BACHE,EAAG,+BACHE,EAAG,+BACHmI,EAAG,mBACHC,EAAG,kCAELuC,kBACE1C,EAAG,WACH7I,EAAG,YACHI,EAAG,WACHI,EAAG,eAELgL,WACExL,EAAG,yBAELyL,gBACE5C,EAAG,iBACH7I,EAAG,kBAEL0L,cACE7C,EAAG,qBACH7I,EAAG,wBAEL2L,aACE9C,EAAG,OACH7I,EAAG,cACHI,EAAG,eACHI,EAAG,gBACHE,EAAG,kBAELkL,UACE/C,EAAG,SACH7I,EAAG,OACHI,EAAG,QAELyL,YACEhD,EAAG,SACH7I,EAAG,iBACHI,EAAG,mBAEL0L,WACEjD,EAAG,SACH7I,EAAG,OACHI,EAAG,QAEL2L,sBACElD,EAAG,UACH7I,EAAG,QACHI,EAAG,aACHI,EAAG,gBAELwL,YACExL,EAAG,OAELyL,yBACEpD,EAAG,GACH7I,EAAG,IACHI,EAAG,KACHI,EAAG,KACHE,EAAG,IACHE,EAAG,IACHkI,EAAG,KAELnJ,aACEK,EAAG,WACHI,EAAG,YACHI,EAAG,eACHE,EAAG,cACHE,EAAG,WACHkI,EAAG,YACHC,EAAG,eACHC,EAAG,gBAIP/R,EAAUwI,QAAQzG,UAAUkT,QAAU,SAAUrM,GAC9C,IAAIsM,EAAQtS,KAAK+F,IAAIC,GACrB,OAAQA,GACN,IAAK,cACL,IAAK,QACL,IAAK,eACL,IAAK,kBACL,IAAK,gBACL,IAAK,mBACL,IAAK,YACL,IAAK,iBACL,IAAK,eACL,IAAK,cACL,IAAK,WACL,IAAK,aACL,IAAK,YACL,IAAK,uBACL,IAAK,aACL,IAAK,cACH,OAAOhG,KAAKyQ,aAAazK,GAAIsM,GAC/B,IAAK,cACL,IAAK,kBACH,IAAKA,EAAO,OACZ,OAAO9L,OAAOC,aAAa6L,EAAM,GAAIA,EAAM,GAAIA,EAAM,GAAIA,EAAM,IACjE,IAAK,0BACH,IAAKA,EAAO,OACZ,OACEtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ
,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAEhC,IAAK,eACH,IAAKA,EAAO,OACZ,OAAOA,EAAM,GAAK,IAAMA,EAAM,GAAK,IAAMA,EAAM,GAAK,IAAMA,EAAM,GAEpE,OAAO9L,OAAO8L,IAEf,SAAWC,GACV,IAEIC,EAFA/J,EAAO8J,EAAiB9J,KACxB5C,EAAM0M,EAAiB1M,IAG3B,IAAK2M,KAAQ/J,EACPA,EAAK7H,eAAe4R,KACtB3M,EAAI4C,EAAK+J,IAASA,GAPvB,CAUEpV,EAAUwI,QAAQzG,WAErB/B,EAAUwI,QAAQzG,UAAUsT,OAAS,WACnC,IACID,EACAxM,EAFAH,KAGJ,IAAK2M,KAAQxS,KACPA,KAAKY,eAAe4R,KACtBxM,EAAKhG,KAAKyI,KAAK+J,MAEb3M,EAAIG,GAAMhG,KAAKqS,QAAQrM,IAI7B,OAAOH,KCpXV,SAAW5F,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsB,qBAAsBM,GACzC,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EACEC,QAAQ,gBACRA,QAAQ,sBACRA,QAAQ,sBAIVD,EAAQF,OAAO3C,WAblB,CAeE,SAAUA,GACX,aAEA,IAAIsV,EAA0BtV,EAAU0E,gBACpC6Q,EAAwBvV,EAAUgI,cAClCwN,EAA+BxV,EAAUiD,qBACzCwS,EAAgCzV,EAAUkD,sBAG9ClD,EAAU0E,gBAAkB,SAAUvE,GACpC,QACIA,EAAQuV,aAAeJ,EAAwBrT,KAAKjC,EAAWG,IAKrEH,EAAUgI,cAAgB,SAAU7H,GAClC,OACGA,IAAmC,IAAxBA,EAAQuV,aACpBH,EAAsBtT,KAAKjC,EAAWG,IAM1CH,EAAUiD,qBAAuB,SAAUc,EAAQ5D,GACjDqV,EAA6BvT,KAAKjC,EAAW+D,EAAQ5D,GACrD,IAAIwV,EAAM5R,EAAOS,WAAW,MACxBnB,EAAQU,EAAOV,MACfC,EAASS,EAAOT,OAChBsS,EAAa7R,EAAO8B,MAAMxC,MAC1BwS,EAAc9R,EAAO8B,MAAMvC,OAC3BoS,EAAcvV,EAAQuV,YAC1B,GAAKA,KAAeA,EAAc,GASlC,OANIA,EAAc,IAChB3R,EAAOV,MAAQC,EACfS,EAAOT,OAASD,EAChBU,EAAO8B,MAAMxC,MAAQwS,EACrB9R,EAAO8B,MAAMvC,OAASsS,GAEhBF,GACN,KAAK,EAEHC,EAAIG,UAAUzS,EAAO,GACrBsS,EAAI3S,OAAO,EAAG,GACd,MACF,KAAK,EAEH2S,EAAIG,UAAUzS,EAAOC,GACrBqS,EAAII,OAAOnR,KAAKoR,IAChB,MACF,KAAK,EAEHL,EAAIG,UAAU,EAAGxS,GACjBqS,EAAI3S,MAAM,GAAI,GACd,MACF,KAAK,EAEH2S,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAI3S,MAAM,GAAI,GACd,MACF,KAAK,EAEH2S,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAIG,UAAU,GAAIxS,GAClB,MACF,KAAK,EAEHqS,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAIG,UAAUzS,GAAQC,GACtBqS,EAAI3S,OAAO,EAAG,GACd,MACF,KAAK,EAEH2S,EAAII,QAAQ,GAAMnR,KAAKoR,IACvBL,EAAIG,WAAWzS,EAAO,KAO5BrD,EAAUkD,sBAAwB,SAAU7C,EAAK4V,EAAM9T,GACrD,IAEIgB,EACAC,EAHAjD,EAAUsV,EAA8BxT,KAAKjC,EAAWK,EAAK4V,GAC7DP,EAAcvV,EAAQuV,YAM1B,IAHoB,IAAhBA,GAAwBvT,GAAQA,EAAKqI,OACvCkL,EAAcvT,EAAKqI,KAAK7B,IAAI,iBAEzB+M,GAAeA,EAAc,GAAqB,IAAhBA,EACrC,OAAOvV,EAETgD,KACA,IAAKC,KAAKjD,EACJA,EAAQqD,eAAeJ,KACzBD,EAAWC,GAAKjD,EAAQiD,IAI5B,OADAD,EAAWuS,YAAcA,EACjBA,GACN,KAAK,EAEHvS,EAAWmC,KAAOnF,EAAQsF,MAC1BtC,EAAWsC,MAAQtF,EAAQmF,KAC3B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQsF,MAC1BtC,EAAWoC,IAAMpF,EAAQuF,OACzBvC,EAAWsC,MAAQtF,EAAQmF,KAC3BnC,EAAWuC,OAASvF,EAAQoF,IAC5B,MACF,KAAK,EAEHpC,EAAWoC,IAAMpF,EAAQuF,OACzBvC,EAAWuC,OAASvF,EAAQoF,IAC5B,MACF,KAAK,EAEHpC,EAAWmC,KAAOnF,EAAQoF,IAC1BpC,EAAWoC,IAAMpF,EAAQmF,KACzBnC,EAAWsC,MAAQtF,EAAQuF,OAC3BvC,EAAWuC,OAASvF,EAAQsF,MAC5B,MACF,KAAK,EAEHtC,EAAWmC,KAAOnF,EAAQoF,IAC1BpC,EAAWoC,IAAMpF,EAAQsF,MACzBtC,EAAWsC,MAAQtF,EAAQuF,OAC3BvC,EAAWuC,OAASvF,EAAQmF,KAC5B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQuF,OAC1BvC,EAAWoC,IAAMpF,EAAQsF,MACzBtC,EAAWsC,MAAQtF,EAAQoF,IAC3BpC,EAAWuC,OAASvF,EAAQmF,KAC5B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQuF,OAC1BvC,EAAWoC,IAAMpF,EAAQmF,KACzBnC,EAAWsC,MAAQtF,EAAQoF,IAC3BpC,EAAWuC,OAASvF,EAAQsF,MAWhC,OARItC,EAAWuS,YAAc,IAC3BvS,EAAWS,SAAWzD,EAAQ0D,UAC9BV,EAAWU,UAAY1D,EAAQyD,SAC/BT,EAAW2B,SAAW3E,EAAQ4E,UAC9B5B,EAAW4B,UAAY5E,EAAQ2E,SAC/B3B,EAAWe,YAAc/D,EAAQgE,aACjChB,EAAWgB,aAAehE,EAAQ+D,aAE7Bf"}
2 |
--------------------------------------------------------------------------------
/demo/web/media/DownloadIcon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/DownloadIcon.png
--------------------------------------------------------------------------------
/demo/web/media/EditIcon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/EditIcon.png
--------------------------------------------------------------------------------
/demo/web/media/beyonce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/beyonce.png
--------------------------------------------------------------------------------
/demo/web/media/cersei.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/cersei.png
--------------------------------------------------------------------------------
/demo/web/media/geoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/geoff.png
--------------------------------------------------------------------------------
/demo/web/media/john.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/john.png
--------------------------------------------------------------------------------
/demo/web/media/lena.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/lena.png
--------------------------------------------------------------------------------
/demo/web/media/leo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/leo.png
--------------------------------------------------------------------------------
/demo/web/media/loading.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/loading.png
--------------------------------------------------------------------------------
/demo/web/media/louis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/louis.png
--------------------------------------------------------------------------------
/demo/web/media/neil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/neil.png
--------------------------------------------------------------------------------
/demo/web/media/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder.png
--------------------------------------------------------------------------------
/demo/web/media/placeholder2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder2.png
--------------------------------------------------------------------------------
/demo/web/media/placeholder4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder4.png
--------------------------------------------------------------------------------
/demo/web/media/rashida.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/rashida.png
--------------------------------------------------------------------------------
/demo/web/media/seth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/seth.png
--------------------------------------------------------------------------------
/demo/web/media/steve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/steve.png
--------------------------------------------------------------------------------
/demo/web/mock.css:
--------------------------------------------------------------------------------
1 | /* mock.css
2 | *
3 | * CSS for the mock-up mimicking the article around the demo.
4 | */
5 |
6 | /* Global Rules */
7 |
8 | html {
9 | font-family: Lato, Helvetica, Arial, sans-serif;
10 | text-rendering: optimizeLegibility;
11 | }
12 |
13 | body {
14 | margin: 0em;
15 | font-size: 1.1em;
16 | line-height: 1.5em;
17 | }
18 |
19 | /* Article Title */
20 |
21 | div.TitlePanel {
22 | background: rgb(52,51,57);
23 | background: linear-gradient(349deg, rgb(85, 83, 95) 0%,
24 | rgb(139, 139, 163) 100%);
25 | color: rgb(255, 255, 255);
26 | display: flex;
27 | height: 337px;
28 | }
29 |
30 | div.Title {
31 | margin: auto;
32 | max-width: 555px;
33 | text-align: center;
34 | }
35 |
36 | .Title h1 {
37 | font-size: 2.4em;
38 | line-height: 1.2em;
39 | margin-bottom: 0.3em;
40 | }
41 |
42 | .Title time {
43 | text-transform: uppercase;
44 | color: #cacaca;
45 | font-size: 0.7em;
46 | font-weight: 700;
47 | display: block;
48 | text-align: center;
49 | }
50 |
51 | /* Article Content */
52 |
53 | div.Content {
54 | color: #111;
55 | max-width: 570px;
56 | padding: 1em;
57 | margin: 3em auto;
58 | word-wrap: break-word;
59 | }
60 |
61 | .MockUpNotice {
62 | color: #b32e2e;
63 | }
64 |
65 | .Content p {
66 | margin: 2em auto;
67 | }
68 | .Content h1, .Content h2, .Content h3, .Content h4, .Content h5, .Content h6 {
69 | font-size: 1.3em;
70 | margin-top: 2.5em;
71 | margin-bottom: 0.5em;
72 | }
--------------------------------------------------------------------------------
/graphics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import time
4 | import threading
5 |
6 |
7 | def save_image(x, path):
8 | im = Image.fromarray(x)
9 | im.save(path, optimize=True)
10 | return
11 |
12 | # Assumes [NHWC] format (to_raster below indexes channels last)
13 | def save_raster(x, path, rescale=False, width=None):
14 | t = threading.Thread(target=_save_raster, args=(x, path, rescale, width))
15 | t.start()
16 |
17 |
18 | def _save_raster(x, path, rescale, width):
19 | x = to_raster(x, rescale, width)
20 | save_image(x, path)
21 |
22 | # Shape: (n_patches,rows,columns,channels)
23 | def to_raster_old(x, rescale=False, width=None):
24 | x = np.transpose(x, (0, 3, 1, 2))
25 |
26 | #x = x.swapaxes(2, 3)
27 | if len(x.shape) == 3:
28 | x = x.reshape((x.shape[0], 1, x.shape[1], x.shape[2]))
29 | if x.shape[1] == 1:
30 | x = np.repeat(x, 3, axis=1)
31 | if rescale:
32 | x = (x - x.min()) / (x.max() - x.min()) * 255.
33 | x = np.clip(x, 0, 255)
34 | assert len(x.shape) == 4
35 | assert x.shape[1] == 3
36 | n_patches = x.shape[0]
37 | if width is None:
38 | width = int(np.ceil(np.sqrt(n_patches))) # result width
39 | height = int(n_patches/width) # result height
40 | tile_height = x.shape[2]
41 | tile_width = x.shape[3]
42 | result = np.zeros((3, int(height*tile_height),
43 | int(width*tile_width)), dtype='uint8')
44 | for i in range(height):
45 | for j in range(width):
46 |             result[:, i*tile_height:(i+1)*tile_height,
47 |                    j*tile_width:(j+1)*tile_width] = x[i*width+j]
48 | return result
49 |
50 |
51 | # Shape: (n_patches,rows,columns,channels)
52 | def to_raster(x, rescale=False, width=None):
53 | if len(x.shape) == 3:
54 | x = x.reshape((x.shape[0], x.shape[1], x.shape[2], 1))
55 | if x.shape[3] == 1:
56 | x = np.repeat(x, 3, axis=3)
57 | if rescale:
58 | x = (x - x.min()) / (x.max() - x.min()) * 255.
59 | x = np.clip(x, 0, 255)
60 | assert len(x.shape) == 4
61 | assert x.shape[3] == 3
62 | n_batch = x.shape[0]
63 | if width is None:
64 | width = int(np.ceil(np.sqrt(n_batch))) # result width
65 | height = int(n_batch / width) # result height
66 | tile_height = x.shape[1]
67 | tile_width = x.shape[2]
68 | result = np.zeros((int(height * tile_height),
69 | int(width * tile_width), 3), dtype='uint8')
70 | for i in range(height):
71 | for j in range(width):
72 | result[i * tile_height:(i + 1) * tile_height, j *
73 | tile_width:(j + 1) * tile_width] = x[width*i+j]
74 | return result
75 |
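A usage sketch for graphics.py (hypothetical, not part of the repo): tile a batch of NHWC uint8 images into a square grid and write it to disk.

    import numpy as np
    import graphics

    batch = np.random.randint(0, 256, size=(16, 32, 32, 3), dtype='uint8')  # NHWC
    grid = graphics.to_raster(batch)            # (128, 128, 3) uint8, a 4x4 grid
    graphics.save_raster(batch, 'samples.png')  # same grid, saved on a worker thread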
--------------------------------------------------------------------------------
/memory_saving_gradients.py:
--------------------------------------------------------------------------------
1 | from toposort import toposort
2 | import contextlib
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow.contrib.graph_editor as ge
6 | import time
7 | import sys
8 | sys.setrecursionlimit(10000)
9 | # refers back to current module if we decide to split helpers out
10 | util = sys.modules[__name__]
11 |
12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated"
13 | setattr(tf.GraphKeys, "VARIABLES", "variables")
14 |
15 | # save original gradients since tf.gradient could be monkey-patched to point
16 | # to our version
17 | from tensorflow.python.ops import gradients as tf_gradients_lib
18 | tf_gradients = tf_gradients_lib.gradients
19 |
20 | MIN_CHECKPOINT_NODE_SIZE = 1024 # use lower value during testing
21 |
22 | # specific versions we can use to do process-wide replacement of tf.gradients
23 |
24 |
25 | def gradients_speed(ys, xs, grad_ys=None, **kwargs):
26 | return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs)
27 |
28 |
29 | def gradients_memory(ys, xs, grad_ys=None, **kwargs):
30 | return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs)
31 |
32 |
33 | def gradients_collection(ys, xs, grad_ys=None, **kwargs):
34 | return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs)
35 |
36 |
37 | def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
38 |     '''
39 |     Authors: Tim Salimans & Yaroslav Bulatov
40 |
41 |     Memory-efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
42 |     by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)
43 |
44 |     ys, xs, grad_ys, kwargs are the arguments to the standard tensorflow tf.gradients
45 |     (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)
46 |
47 |     'checkpoints' can either be:
48 |     - a list consisting of tensors from the forward pass of the neural net
49 |       that we should re-use when calculating the gradients in the backward pass;
50 |       all other tensors that do not appear in this list will be re-computed
51 |     - a string specifying how this list should be determined. Currently we support:
52 |         - 'speed': checkpoint all outputs of convolutions and matmuls. These ops are usually the most expensive,
53 |           so checkpointing them maximizes the running speed
54 |           (this is a good option if nonlinearities, concats, batchnorms, etc. are taking up a lot of memory)
55 |         - 'memory': try to minimize the memory usage
56 |           (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
57 |         - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
58 |     '''
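    # Illustrative usage (a sketch, not part of the original file): with TF 1.x
    # graph mode this module is typically swapped in for tf.gradients before the
    # train op is built, e.g.
    #
    #   import tensorflow as tf
    #   import memory_saving_gradients
    #   tf.__dict__["gradients"] = memory_saving_gradients.gradients_memory
    #
    # so that backprop recomputes activations from the checkpointed tensors
    # instead of keeping every activation in memory.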
59 |
60 | # print("Calling memsaving gradients with", checkpoints)
61 | if not isinstance(ys, list):
62 | ys = [ys]
63 | if not isinstance(xs, list):
64 | xs = [xs]
65 |
66 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
67 | inclusive=True)
68 |
69 | debug_print("bwd_ops: %s", bwd_ops)
70 |
71 | # forward ops are all ops that are candidates for recomputation
72 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
73 | inclusive=True,
74 | within_ops=bwd_ops)
75 | debug_print("fwd_ops: %s", fwd_ops)
76 |
77 | # exclude ops with no inputs
78 | fwd_ops = [op for op in fwd_ops if op.inputs]
79 |
80 | # don't recompute xs, remove variables
81 | xs_ops = _to_ops(xs)
82 |     fwd_ops = [op for op in fwd_ops if op not in xs_ops]
83 |     fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
84 |     fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
85 |     fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
86 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
87 | ts_all = [t for t in ts_all if '/read' not in t.name]
88 | ts_all = set(ts_all) - set(xs) - set(ys)
89 |
90 | # construct list of tensors to checkpoint during forward pass, if not
91 | # given as input
92 | if type(checkpoints) is not list:
93 | if checkpoints == 'collection':
94 | checkpoints = tf.get_collection('checkpoints')
95 |
96 | elif checkpoints == 'speed':
97 | # checkpoint all expensive ops to maximize running speed
98 | checkpoints = ge.filter_ts_from_regex(
99 | fwd_ops, 'conv2d|Conv|MatMul')
100 |
101 | elif checkpoints == 'memory':
102 |
103 | # remove very small tensors and some weird ops
104 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually
105 | try:
106 | return [int(e if e.value is not None else 64) for e in t]
107 |             except Exception:
108 | return [0] # unknown shape
109 | ts_all = [t for t in ts_all if np.prod(
110 | fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
111 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
112 | ts_all = [t for t in ts_all if 'entropy' not in t.name]
113 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
114 | ts_all = [t for t in ts_all if 'Switch' not in t.name]
115 | ts_all = [t for t in ts_all if 'dropout' not in t.name]
116 |
117 | # filter out all tensors that are inputs of the backward graph
118 | with util.capture_ops() as bwd_ops:
119 | tf_gradients(ys, xs, grad_ys, **kwargs)
120 |
121 | bwd_inputs = [t for op in bwd_ops for t in op.inputs]
122 |             # list of tensors in the forward graph that are inputs to the bwd graph
123 | ts_filtered = list(set(bwd_inputs).intersection(ts_all))
124 | debug_print("Using tensors %s", ts_filtered)
125 |
126 |             # try two slightly different ways of getting bottleneck tensors
127 | # to checkpoint
128 | for ts in [ts_filtered, ts_all]:
129 |
130 | # get all bottlenecks in the graph
131 | bottleneck_ts = []
132 | for t in ts:
133 | b = set(ge.get_backward_walk_ops(
134 | t.op, inclusive=True, within_ops=fwd_ops))
135 | f = set(ge.get_forward_walk_ops(
136 | t.op, inclusive=False, within_ops=fwd_ops))
137 |                 # check that there are no shortcuts
138 | b_inp = set(
139 | [inp for op in b for inp in op.inputs]).intersection(ts_all)
140 | f_inp = set(
141 | [inp for op in f for inp in op.inputs]).intersection(ts_all)
142 | if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
143 | bottleneck_ts.append(t) # we have a bottleneck!
144 | else:
145 | debug_print("Rejected bottleneck candidate and ops %s", [
146 | t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
147 |
148 | # success? or try again without filtering?
149 | if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found!
150 | break
151 |
152 | if not bottleneck_ts:
153 | raise Exception(
154 | 'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')
155 |
156 | # sort the bottlenecks
157 | bottlenecks_sorted_lists = tf_toposort(
158 | bottleneck_ts, within_ops=fwd_ops)
159 | sorted_bottlenecks = [
160 | t for ts in bottlenecks_sorted_lists for t in ts]
161 |
162 | # save an approximately optimal number ~ sqrt(N)
163 | N = len(ts_filtered)
164 | if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
165 | checkpoints = sorted_bottlenecks
166 | else:
167 | step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
168 | checkpoints = sorted_bottlenecks[step::step]
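                # Worked example (illustrative numbers, not from the code): with
                # 40 bottlenecks and N = 100 filtered tensors, step becomes
                # ceil(40 / 10) = 4, so roughly every 4th bottleneck is kept,
                # giving the ~sqrt(N) spacing this heuristic targets.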
169 |
170 | else:
171 | raise Exception(
172 | '%s is unsupported input for "checkpoints"' % (checkpoints,))
173 |
174 | checkpoints = list(set(checkpoints).intersection(ts_all))
175 |
176 |     # at this point automatic selection has happened and checkpoints is a list of nodes
177 | assert isinstance(checkpoints, list)
178 |
179 | debug_print("Checkpoint nodes used: %s", checkpoints)
180 | # better error handling of special cases
181 | # xs are already handled as checkpoint nodes, so no need to include them
182 | xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
183 | if xs_intersect_checkpoints:
184 | debug_print("Warning, some input nodes are also checkpoint nodes: %s",
185 | xs_intersect_checkpoints)
186 | ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
187 | debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
188 | ys_intersect_checkpoints)
189 |     # saving an output node (ys) gives no memory benefit while creating
190 |     # new edge cases, so exclude them
191 | if ys_intersect_checkpoints:
192 |         debug_print("Warning, some output nodes are also checkpoint nodes: %s",
193 | format_ops(ys_intersect_checkpoints))
194 |
195 | # remove initial and terminal nodes from checkpoints list if present
196 | checkpoints = list(set(checkpoints) - set(ys) - set(xs))
197 |
198 | # check that we have some nodes to checkpoint
199 | if not checkpoints:
200 |         raise Exception('no checkpoint nodes found or given as input!')
201 |
202 | # disconnect dependencies between checkpointed tensors
203 | checkpoints_disconnected = {}
204 | for x in checkpoints:
205 | if x.op and x.op.name is not None:
206 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
207 | else:
208 | grad_node = tf.stop_gradient(x)
209 | checkpoints_disconnected[x] = grad_node
210 |
211 | # partial derivatives to the checkpointed tensors and xs
212 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
213 | stop_at_ts=checkpoints, within_ops=fwd_ops)
214 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
215 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
216 | debug_print("ops_to_copy = %s", ops_to_copy)
217 | debug_print("Processing list %s", ys)
218 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
219 | copied_ops = info._transformed_ops.values()
220 | debug_print("Copied %s to %s", ops_to_copy, copied_ops)
221 | ge.reroute_ts(checkpoints_disconnected.values(),
222 | checkpoints_disconnected.keys(), can_modify=copied_ops)
223 | debug_print("Rewired %s in place of %s restricted to %s",
224 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)
225 |
226 | # get gradients with respect to current boundary + original x's
227 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
228 | boundary = list(checkpoints_disconnected.values())
229 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
230 | debug_print("Got gradients %s", dv)
231 | debug_print("for %s", copied_ys)
232 | debug_print("with respect to %s", boundary+xs)
233 |
234 | inputs_to_do_before = [y.op for y in ys]
235 | if grad_ys is not None:
236 | inputs_to_do_before += grad_ys
237 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
238 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
239 |
240 | # partial derivatives to the checkpointed nodes
241 | # dictionary of "node: backprop" for nodes in the boundary
242 | d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
243 | dv[:len(checkpoints_disconnected)])}
244 | # partial derivatives to xs (usually the params of the neural net)
245 | d_xs = dv[len(checkpoints_disconnected):]
246 |
247 | # incorporate derivatives flowing through the checkpointed nodes
248 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
249 | for ts in checkpoints_sorted_lists[::-1]:
250 | debug_print("Processing list %s", ts)
251 | checkpoints_other = [r for r in checkpoints if r not in ts]
252 | checkpoints_disconnected_other = [
253 | checkpoints_disconnected[r] for r in checkpoints_other]
254 |
255 | # copy part of the graph below current checkpoint node, stopping at
256 |     # other checkpoint nodes
257 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[
258 | r.op for r in ts], stop_at_ts=checkpoints_other)
259 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
260 | len(ops_to_copy), fwd_ops, [r.op for r in ts],
261 | checkpoints_other)
262 | debug_print("ops_to_copy = %s", ops_to_copy)
263 | if not ops_to_copy: # we're done!
264 | break
265 | copied_sgv, info = ge.copy_with_input_replacements(
266 | ge.sgv(ops_to_copy), {})
267 | copied_ops = info._transformed_ops.values()
268 | debug_print("Copied %s to %s", ops_to_copy, copied_ops)
269 | ge.reroute_ts(checkpoints_disconnected_other,
270 | checkpoints_other, can_modify=copied_ops)
271 | debug_print("Rewired %s in place of %s restricted to %s",
272 | checkpoints_disconnected_other, checkpoints_other, copied_ops)
273 |
274 | # gradient flowing through the checkpointed node
275 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
276 | substitute_backprops = [d_checkpoints[r] for r in ts]
277 | dv = tf_gradients(boundary,
278 | checkpoints_disconnected_other+xs,
279 | grad_ys=substitute_backprops, **kwargs)
280 | debug_print("Got gradients %s", dv)
281 | debug_print("for %s", boundary)
282 | debug_print("with respect to %s", checkpoints_disconnected_other+xs)
283 | debug_print("with boundary backprop substitutions %s",
284 | substitute_backprops)
285 |
286 | inputs_to_do_before = [d_checkpoints[r].op for r in ts]
287 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
288 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
289 |
290 | # partial derivatives to the checkpointed nodes
291 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
292 | if dr is not None:
293 | if d_checkpoints[r] is None:
294 | d_checkpoints[r] = dr
295 | else:
296 | d_checkpoints[r] += dr
297 |
298 | # partial derivatives to xs (usually the params of the neural net)
299 | d_xs_new = dv[len(checkpoints_other):]
300 | for j in range(len(xs)):
301 | if d_xs_new[j] is not None:
302 | if d_xs[j] is None:
303 | d_xs[j] = d_xs_new[j]
304 | else:
305 | d_xs[j] += d_xs_new[j]
306 |
307 | return d_xs
308 |
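# A minimal usage sketch (hypothetical model code): gradients() above is a
# drop-in replacement for tf.gradients that trades one extra forward pass per
# segment for a roughly O(sqrt(N)) reduction in stored activations.
#
#     import tensorflow as tf
#     import memory_saving_gradients
#     loss = build_model()  # hypothetical user-defined graph
#     params = tf.trainable_variables()
#     grads = memory_saving_gradients.gradients(loss, params,
#                                               checkpoints='memory')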
309 |
310 | def tf_toposort(ts, within_ops=None):
311 | all_ops = ge.get_forward_walk_ops(
312 | [x.op for x in ts], within_ops=within_ops)
313 |
314 | deps = {}
315 | for op in all_ops:
316 | for o in op.outputs:
317 | deps[o] = set(op.inputs)
318 | sorted_ts = toposort(deps)
319 |
320 | # only keep the tensors from our original list
321 | ts_sorted_lists = []
322 | for l in sorted_ts:
323 | keep = list(set(l).intersection(ts))
324 | if keep:
325 | ts_sorted_lists.append(keep)
326 |
327 | return ts_sorted_lists
328 |
329 |
330 | def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
331 | bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
332 | ops = bwd_ops.intersection(within_ops).difference(
333 | [t.op for t in stop_at_ts])
334 | return list(ops)
335 |
336 |
337 | @contextlib.contextmanager
338 | def capture_ops():
339 | """Decorator to capture ops created in the block.
340 | with capture_ops() as ops:
341 | # create some ops
342 | print(ops) # => prints ops created.
343 | """
344 |
345 | micros = int(time.time()*10**6)
346 | scope_name = str(micros)
347 | op_list = []
348 | with tf.name_scope(scope_name):
349 | yield op_list
350 |
351 | g = tf.get_default_graph()
352 | op_list.extend(ge.select_ops(scope_name+"/.*", graph=g))
353 |
354 |
355 | def _to_op(tensor_or_op):
356 | if hasattr(tensor_or_op, "op"):
357 | return tensor_or_op.op
358 | return tensor_or_op
359 |
360 |
361 | def _to_ops(iterable):
362 | if not _is_iterable(iterable):
363 | return iterable
364 | return [_to_op(i) for i in iterable]
365 |
366 |
367 | def _is_iterable(o):
368 | try:
369 | _ = iter(o)
370 | except Exception:
371 | return False
372 | return True
373 |
374 |
375 | DEBUG_LOGGING = False
376 |
377 |
378 | def debug_print(s, *args):
379 | """Like logger.log, but also replaces all TensorFlow ops/tensors with their
380 | names. Sensitive to value of DEBUG_LOGGING, see enable_debug/disable_debug
381 |
382 | Usage:
383 | debug_print("see tensors %s for %s", tensorlist, [1,2,3])
384 | """
385 |
386 | if DEBUG_LOGGING:
387 | formatted_args = [format_ops(arg) for arg in args]
388 | print("DEBUG "+s % tuple(formatted_args))
389 |
390 |
391 | def format_ops(ops, sort_outputs=True):
392 | """Helper method for printing ops. Converts Tensor/Operation op to op.name,
393 | rest to str(op)."""
394 |
395 | if hasattr(ops, '__iter__') and not isinstance(ops, str):
396 | l = [(op.name if hasattr(op, "name") else str(op)) for op in ops]
397 | if sort_outputs:
398 | return sorted(l)
399 | return l
400 | else:
401 | return ops.name if hasattr(ops, "name") else str(ops)
402 |
403 |
404 | def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before):
405 | for op in wait_to_do_ops:
406 | ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs]
407 | ge.add_control_inputs(op, ci)
408 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import tfops as Z
4 | import optim
5 | import numpy as np
6 | import horovod.tensorflow as hvd
7 | from tensorflow.contrib.framework.python.ops import add_arg_scope
8 |
9 |
10 | '''
11 | f_loss: function taking (x, y, reuse=False) as input and returning a list/tuple whose first element is the loss.
12 | '''
13 |
14 |
15 | def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
16 |
17 | # == Create class with static fields and methods
18 | class m(object):
19 | pass
20 | m.sess = sess
21 | m.feeds = feeds
22 | m.lr = lr
23 |
24 | # === Loss and optimizer
25 | loss_train, stats_train = f_loss(train_iterator, True)
26 | all_params = tf.trainable_variables()
27 | if hps.gradient_checkpointing == 1:
28 | from memory_saving_gradients import gradients
29 | gs = gradients(loss_train, all_params)
30 | else:
31 | gs = tf.gradients(loss_train, all_params)
32 |
33 | optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
34 | 'adam2': optim.adam2}[hps.optimizer]
35 |
36 | train_op, polyak_swap_op, ema = optimizer(
37 | all_params, gs, alpha=lr, hps=hps)
38 | if hps.direct_iterator:
39 | m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
40 | else:
41 | def _train(_lr):
42 | _x, _y = train_iterator()
43 | return sess.run([train_op, stats_train], {feeds['x']: _x,
44 | feeds['y']: _y, lr: _lr})[1]
45 | m.train = _train
46 |
47 | m.polyak_swap = lambda: sess.run(polyak_swap_op)
48 |
49 | # === Testing
50 | loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
51 | if hps.direct_iterator:
52 | m.test = lambda: sess.run(stats_test)
53 | else:
54 | def _test():
55 | _x, _y = test_iterator()
56 | return sess.run(stats_test, {feeds['x']: _x,
57 | feeds['y']: _y})
58 | m.test = _test
59 |
60 | # === Saving and restoring
61 | saver = tf.train.Saver()
62 | saver_ema = tf.train.Saver(ema.variables_to_restore())
63 | m.save_ema = lambda path: saver_ema.save(
64 | sess, path, write_meta_graph=False)
65 | m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
66 | m.restore = lambda path: saver.restore(sess, path)
67 |
68 | # === Initialize the parameters
69 | if hps.restore_path != '':
70 | m.restore(hps.restore_path)
71 | else:
72 | with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
73 | results_init = f_loss(None, True, reuse=True)
74 | sess.run(tf.global_variables_initializer())
75 | sess.run(results_init, {feeds['x']: data_init['x'],
76 | feeds['y']: data_init['y']})
77 | sess.run(hvd.broadcast_global_variables(0))
78 |
79 | return m
80 |
81 |
82 | def codec(hps):
83 |
84 | def encoder(z, objective):
85 | eps = []
86 | for i in range(hps.n_levels):
87 | z, objective = revnet2d(str(i), z, objective, hps)
88 | if i < hps.n_levels-1:
89 | z, objective, _eps = split2d("pool"+str(i), z, objective=objective)
90 | eps.append(_eps)
91 | return z, objective, eps
92 |
93 | def decoder(z, eps=[None]*hps.n_levels, eps_std=None):
94 | for i in reversed(range(hps.n_levels)):
95 | if i < hps.n_levels-1:
96 | z = split2d_reverse("pool"+str(i), z, eps=eps[i], eps_std=eps_std)
97 | z, _ = revnet2d(str(i), z, 0, hps, reverse=True)
98 |
99 | return z
100 |
101 | return encoder, decoder
102 |
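# Shape walk-through of the codec (assuming a 32x32x3 input squeezed to
# 16x16x12 before encoder() is called, with n_levels=3): each revnet2d block
# preserves shape, while each split2d halves the channels and squeezes 2x, so
# activations go 16x16x12 -> 8x8x24 -> 4x4x48, with the split-off halves
# factored out as Gaussian latents (eps); decoder() inverts the sequence.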
103 |
104 | def prior(name, y_onehot, hps):
105 |
106 | with tf.variable_scope(name):
107 | n_z = hps.top_shape[-1]
108 |
109 | h = tf.zeros([tf.shape(y_onehot)[0]]+hps.top_shape[:2]+[2*n_z])
110 | if hps.learntop:
111 | h = Z.conv2d_zeros('p', h, 2*n_z)
112 | if hps.ycond:
113 | h += tf.reshape(Z.linear_zeros("y_emb", y_onehot,
114 | 2*n_z), [-1, 1, 1, 2 * n_z])
115 |
116 | pz = Z.gaussian_diag(h[:, :, :, :n_z], h[:, :, :, n_z:])
117 |
118 | def logp(z1):
119 | objective = pz.logp(z1)
120 | return objective
121 |
122 | def sample(eps=None, eps_std=None):
123 | if eps is not None:
124 | # Already sampled eps. Don't use eps_std
125 | z = pz.sample2(eps)
126 | elif eps_std is not None:
127 | # Sample with given eps_std
128 | z = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1]))
129 | else:
130 | # Sample normally
131 | z = pz.sample
132 |
133 | return z
134 |
135 | def eps(z1):
136 | return pz.get_eps(z1)
137 |
138 | return logp, sample, eps
139 |
140 |
141 | def model(sess, hps, train_iterator, test_iterator, data_init):
142 |
143 | # Only for decoding/init, rest use iterators directly
144 | with tf.name_scope('input'):
145 | X = tf.placeholder(
146 | tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image')
147 | Y = tf.placeholder(tf.int32, [None], name='label')
148 | lr = tf.placeholder(tf.float32, None, name='learning_rate')
149 |
150 | encoder, decoder = codec(hps)
151 | hps.n_bins = 2. ** hps.n_bits_x
152 |
153 | def preprocess(x):
154 | x = tf.cast(x, 'float32')
155 | if hps.n_bits_x < 8:
156 | x = tf.floor(x / 2 ** (8 - hps.n_bits_x))
157 | x = x / hps.n_bins - .5
158 | return x
159 |
160 | def postprocess(x):
161 | return tf.cast(tf.clip_by_value(tf.floor((x + .5)*hps.n_bins)*(256./hps.n_bins), 0, 255), 'uint8')
162 |
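    # Worked example of the n_bits_x quantization (assuming n_bits_x=5, so
    # n_bins=32): a uint8 pixel of 200 is floored to 200 // 8 = 25 bins and
    # scaled to 25/32 - 0.5 = 0.28125; postprocess() maps that back to the
    # bin's 8-bit value, 25 * (256/32) = 200, clipped to [0, 255].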
163 | def _f_loss(x, y, is_training, reuse=False):
164 |
165 | with tf.variable_scope('model', reuse=reuse):
166 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
167 |
168 | # Discrete -> Continuous
169 | objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
170 | z = preprocess(x)
171 | z = z + tf.random_uniform(tf.shape(z), 0, 1./hps.n_bins)
172 | objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])
173 |
174 | # Encode
175 | z = Z.squeeze2d(z, 2) # > 16x16x12
176 | z, objective, _ = encoder(z, objective)
177 |
178 | # Prior
179 | hps.top_shape = Z.int_shape(z)[1:]
180 | logp, _, _ = prior("prior", y_onehot, hps)
181 | objective += logp(z)
182 |
183 | # Generative loss
184 | nobj = - objective
185 | bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
186 | x.get_shape()[2]) * int(x.get_shape()[3])) # bits per subpixel
187 |
188 | # Predictive loss
189 | if hps.weight_y > 0 and hps.ycond:
190 |
191 | # Classification loss
192 | h_y = tf.reduce_mean(z, axis=[1, 2])
193 | y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
194 | bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
195 | labels=y_onehot, logits=y_logits) / np.log(2.)
196 |
197 | # Classification accuracy
198 | y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
199 | classification_error = 1 - \
200 | tf.cast(tf.equal(y_predicted, y), tf.float32)
201 | else:
202 | bits_y = tf.zeros_like(bits_x)
203 | classification_error = tf.ones_like(bits_x)
204 |
205 | return bits_x, bits_y, classification_error
206 |
207 | def f_loss(iterator, is_training, reuse=False):
208 | if hps.direct_iterator and iterator is not None:
209 | x, y = iterator.get_next()
210 | else:
211 | x, y = X, Y
212 |
213 | bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
214 | local_loss = bits_x + hps.weight_y * bits_y
215 | stats = [local_loss, bits_x, bits_y, pred_loss]
216 | global_stats = Z.allreduce_mean(
217 | tf.stack([tf.reduce_mean(i) for i in stats]))
218 |
219 | return tf.reduce_mean(local_loss), global_stats
220 |
221 | feeds = {'x': X, 'y': Y}
222 | m = abstract_model_xy(sess, hps, feeds, train_iterator,
223 | test_iterator, data_init, lr, f_loss)
224 |
225 | # === Sampling function
226 | def f_sample(y, eps_std):
227 | with tf.variable_scope('model', reuse=True):
228 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
229 |
230 | _, sample, _ = prior("prior", y_onehot, hps)
231 | z = sample(eps_std=eps_std)
232 | z = decoder(z, eps_std=eps_std)
233 | z = Z.unsqueeze2d(z, 2) # 8x8x12 -> 16x16x3
234 | x = postprocess(z)
235 |
236 | return x
237 |
238 | m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std')
239 | x_sampled = f_sample(Y, m.eps_std)
240 |
241 | def sample(_y, _eps_std):
242 | return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
243 | m.sample = sample
244 |
245 | if hps.inference:
246 | # === Encoder-Decoder functions
247 | def f_encode(x, y, reuse=True):
248 | with tf.variable_scope('model', reuse=reuse):
249 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
250 |
251 | # Discrete -> Continuous
252 | objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
253 | z = preprocess(x)
254 | z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
255 | objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])
256 |
257 | # Encode
258 | z = Z.squeeze2d(z, 2) # > 16x16x12
259 | z, objective, eps = encoder(z, objective)
260 |
261 | # Prior
262 | hps.top_shape = Z.int_shape(z)[1:]
263 | logp, _, _eps = prior("prior", y_onehot, hps)
264 | objective += logp(z)
265 | eps.append(_eps(z))
266 |
267 | return eps
268 |
269 | def f_decode(y, eps, reuse=True):
270 | with tf.variable_scope('model', reuse=reuse):
271 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
272 |
273 | _, sample, _ = prior("prior", y_onehot, hps)
274 | z = sample(eps=eps[-1])
275 | z = decoder(z, eps=eps[:-1])
276 | z = Z.unsqueeze2d(z, 2) # 8x8x12 -> 16x16x3
277 | x = postprocess(z)
278 |
279 | return x
280 |
281 | enc_eps = f_encode(X, Y)
282 | dec_eps = []
283 | print(enc_eps)
284 | for i, _eps in enumerate(enc_eps):
285 | print(_eps)
286 | dec_eps.append(tf.placeholder(tf.float32, _eps.get_shape().as_list(), name="dec_eps_" + str(i)))
287 | dec_x = f_decode(Y, dec_eps)
288 |
289 | eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]
290 |
291 | def flatten_eps(eps):
292 | # [BS, eps_size]
293 | return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)
294 |
295 | def unflatten_eps(feps):
296 | index = 0
297 | eps = []
298 | bs = feps.shape[0]
299 | for shape in eps_shapes:
300 | eps.append(np.reshape(feps[:, index: index+np.prod(shape)], (bs, *shape)))
301 | index += np.prod(shape)
302 | return eps
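        # flatten_eps/unflatten_eps round-trip the per-level latents through a
        # single [batch, total_eps_size] array, using eps_shapes to recover
        # each level's [H, W, C] block, so encode()/decode() expose one flat
        # latent vector per image.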
303 |
304 |         # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
305 | def encode(x, y):
306 | return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))
307 |
308 | def decode(y, feps):
309 | eps = unflatten_eps(feps)
310 | feed_dict = {Y: y}
311 | for i in range(len(dec_eps)):
312 | feed_dict[dec_eps[i]] = eps[i]
313 | return sess.run(dec_x, feed_dict)
314 |
315 | m.encode = encode
316 | m.decode = decode
317 |
318 | return m
319 |
320 |
321 | def checkpoint(z, logdet):
322 | zshape = Z.int_shape(z)
323 | z = tf.reshape(z, [-1, zshape[1]*zshape[2]*zshape[3]])
324 | logdet = tf.reshape(logdet, [-1, 1])
325 | combined = tf.concat([z, logdet], axis=1)
326 | tf.add_to_collection('checkpoints', combined)
327 | logdet = combined[:, -1]
328 | z = tf.reshape(combined[:, :-1], [-1, zshape[1], zshape[2], zshape[3]])
329 | return z, logdet
330 |
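# checkpoint() packs (z, logdet) into one flat tensor before adding it to the
# 'checkpoints' graph collection, so that memory_saving_gradients (enabled via
# hps.gradient_checkpointing) can recompute each flow step from these
# boundaries instead of storing every intermediate activation; the reshape
# round-trip leaves the values untouched.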
331 |
332 | @add_arg_scope
333 | def revnet2d(name, z, logdet, hps, reverse=False):
334 | with tf.variable_scope(name):
335 | if not reverse:
336 | for i in range(hps.depth):
337 | z, logdet = checkpoint(z, logdet)
338 | z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
339 | z, logdet = checkpoint(z, logdet)
340 | else:
341 | for i in reversed(range(hps.depth)):
342 | z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
343 | return z, logdet
344 |
345 | # Simpler, new version
346 | @add_arg_scope
347 | def revnet2d_step(name, z, logdet, hps, reverse):
348 | with tf.variable_scope(name):
349 |
350 | shape = Z.int_shape(z)
351 | n_z = shape[3]
352 | assert n_z % 2 == 0
353 |
354 | if not reverse:
355 |
356 | z, logdet = Z.actnorm("actnorm", z, logdet=logdet)
357 |
358 | if hps.flow_permutation == 0:
359 | z = Z.reverse_features("reverse", z)
360 | elif hps.flow_permutation == 1:
361 | z = Z.shuffle_features("shuffle", z)
362 | elif hps.flow_permutation == 2:
363 | z, logdet = invertible_1x1_conv("invconv", z, logdet)
364 | else:
365 |             raise Exception("unknown flow_permutation: %d" % hps.flow_permutation)
366 |
367 | z1 = z[:, :, :, :n_z // 2]
368 | z2 = z[:, :, :, n_z // 2:]
369 |
370 | if hps.flow_coupling == 0:
371 | z2 += f("f1", z1, hps.width)
372 | elif hps.flow_coupling == 1:
373 | h = f("f1", z1, hps.width, n_z)
374 | shift = h[:, :, :, 0::2]
375 | # scale = tf.exp(h[:, :, :, 1::2])
376 | scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
377 | z2 += shift
378 | z2 *= scale
379 | logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
380 | else:
381 |             raise Exception("unknown flow_coupling: %d" % hps.flow_coupling)
382 |
383 | z = tf.concat([z1, z2], 3)
384 |
385 | else:
386 |
387 | z1 = z[:, :, :, :n_z // 2]
388 | z2 = z[:, :, :, n_z // 2:]
389 |
390 | if hps.flow_coupling == 0:
391 | z2 -= f("f1", z1, hps.width)
392 | elif hps.flow_coupling == 1:
393 | h = f("f1", z1, hps.width, n_z)
394 | shift = h[:, :, :, 0::2]
395 | # scale = tf.exp(h[:, :, :, 1::2])
396 | scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
397 | z2 /= scale
398 | z2 -= shift
399 | logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
400 | else:
401 |             raise Exception("unknown flow_coupling: %d" % hps.flow_coupling)
402 |
403 | z = tf.concat([z1, z2], 3)
404 |
405 | if hps.flow_permutation == 0:
406 | z = Z.reverse_features("reverse", z, reverse=True)
407 | elif hps.flow_permutation == 1:
408 | z = Z.shuffle_features("shuffle", z, reverse=True)
409 | elif hps.flow_permutation == 2:
410 | z, logdet = invertible_1x1_conv(
411 | "invconv", z, logdet, reverse=True)
412 | else:
413 |             raise Exception("unknown flow_permutation: %d" % hps.flow_permutation)
414 |
415 | z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)
416 |
417 | return z, logdet
418 |
419 |
420 | def f(name, h, width, n_out=None):
421 | n_out = n_out or int(h.get_shape()[3])
422 | with tf.variable_scope(name):
423 | h = tf.nn.relu(Z.conv2d("l_1", h, width))
424 | h = tf.nn.relu(Z.conv2d("l_2", h, width, filter_size=[1, 1]))
425 | h = Z.conv2d_zeros("l_last", h, n_out)
426 | return h
427 |
428 |
429 | def f_resnet(name, h, width, n_out=None):
430 | n_out = n_out or int(h.get_shape()[3])
431 | with tf.variable_scope(name):
432 | h = tf.nn.relu(Z.conv2d("l_1", h, width))
433 | h = Z.conv2d_zeros("l_2", h, n_out)
434 | return h
435 |
436 | # Invertible 1x1 conv
437 | @add_arg_scope
438 | def invertible_1x1_conv(name, z, logdet, reverse=False):
439 |
440 | if True: # Set to "False" to use the LU-decomposed version
441 |
442 | with tf.variable_scope(name):
443 |
444 | shape = Z.int_shape(z)
445 | w_shape = [shape[3], shape[3]]
446 |
447 | # Sample a random orthogonal matrix:
448 | w_init = np.linalg.qr(np.random.randn(
449 | *w_shape))[0].astype('float32')
450 |
451 | w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)
452 |
453 | # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
454 | dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
455 | tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]
456 |
457 | if not reverse:
458 |
459 | _w = tf.reshape(w, [1, 1] + w_shape)
460 | z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
461 | 'SAME', data_format='NHWC')
462 | logdet += dlogdet
463 |
464 | return z, logdet
465 | else:
466 |
467 | _w = tf.matrix_inverse(w)
468 | _w = tf.reshape(_w, [1, 1]+w_shape)
469 | z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
470 | 'SAME', data_format='NHWC')
471 | logdet -= dlogdet
472 |
473 | return z, logdet
474 |
475 | else:
476 |
477 | # LU-decomposed version
478 | shape = Z.int_shape(z)
479 | with tf.variable_scope(name):
480 |
481 | dtype = 'float64'
482 |
483 | # Random orthogonal matrix:
484 |             import scipy.linalg
485 | np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
486 | 0].astype('float32')
487 |
488 | np_p, np_l, np_u = scipy.linalg.lu(np_w)
489 | np_s = np.diag(np_u)
490 | np_sign_s = np.sign(np_s)
491 | np_log_s = np.log(abs(np_s))
492 | np_u = np.triu(np_u, k=1)
493 |
494 | p = tf.get_variable("P", initializer=np_p, trainable=False)
495 | l = tf.get_variable("L", initializer=np_l)
496 | sign_s = tf.get_variable(
497 | "sign_S", initializer=np_sign_s, trainable=False)
498 | log_s = tf.get_variable("log_S", initializer=np_log_s)
499 | # S = tf.get_variable("S", initializer=np_s)
500 | u = tf.get_variable("U", initializer=np_u)
501 |
502 | p = tf.cast(p, dtype)
503 | l = tf.cast(l, dtype)
504 | sign_s = tf.cast(sign_s, dtype)
505 | log_s = tf.cast(log_s, dtype)
506 | u = tf.cast(u, dtype)
507 |
508 | w_shape = [shape[3], shape[3]]
509 |
510 | l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
511 | l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
512 | u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
513 | w = tf.matmul(p, tf.matmul(l, u))
514 |
515 |             if True:  # invert via the LU factors rather than inverting w directly
516 | u_inv = tf.matrix_inverse(u)
517 | l_inv = tf.matrix_inverse(l)
518 | p_inv = tf.matrix_inverse(p)
519 | w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
520 | else:
521 | w_inv = tf.matrix_inverse(w)
522 |
523 | w = tf.cast(w, tf.float32)
524 | w_inv = tf.cast(w_inv, tf.float32)
525 | log_s = tf.cast(log_s, tf.float32)
526 |
527 | if not reverse:
528 |
529 | w = tf.reshape(w, [1, 1] + w_shape)
530 | z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
531 | 'SAME', data_format='NHWC')
532 | logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])
533 |
534 | return z, logdet
535 | else:
536 |
537 | w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
538 | z = tf.nn.conv2d(
539 | z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
540 | logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])
541 |
542 | return z, logdet
543 |
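# Sanity-check sketch (hypothetical helper, numpy only) of the dlogdet term
# above: the same channel-mixing W is applied at each of the H*W spatial
# positions, so the change of variables contributes H * W * log|det W|.
def _invconv_logdet_check(h=4, w=4, c=8):
    w_mat = np.linalg.qr(np.random.randn(c, c))[0]  # orthogonal init, |det W| = 1
    return h * w * np.log(abs(np.linalg.det(w_mat)))  # ~0.0 at initialization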
544 |
545 | @add_arg_scope
546 | def split2d(name, z, objective=0.):
547 | with tf.variable_scope(name):
548 | n_z = Z.int_shape(z)[3]
549 | z1 = z[:, :, :, :n_z // 2]
550 | z2 = z[:, :, :, n_z // 2:]
551 | pz = split2d_prior(z1)
552 | objective += pz.logp(z2)
553 | z1 = Z.squeeze2d(z1)
554 | eps = pz.get_eps(z2)
555 | return z1, objective, eps
556 |
557 |
558 | @add_arg_scope
559 | def split2d_reverse(name, z, eps, eps_std):
560 | with tf.variable_scope(name):
561 | z1 = Z.unsqueeze2d(z)
562 | pz = split2d_prior(z1)
563 | if eps is not None:
564 | # Already sampled eps
565 | z2 = pz.sample2(eps)
566 | elif eps_std is not None:
567 | # Sample with given eps_std
568 | z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1]))
569 | else:
570 | # Sample normally
571 | z2 = pz.sample
572 | z = tf.concat([z1, z2], 3)
573 | return z
574 |
575 |
576 | @add_arg_scope
577 | def split2d_prior(z):
578 | n_z2 = int(z.get_shape()[3])
579 | n_z1 = n_z2
580 | h = Z.conv2d_zeros("conv", z, 2 * n_z1)
581 |
582 | mean = h[:, :, :, 0::2]
583 | logs = h[:, :, :, 1::2]
584 | return Z.gaussian_diag(mean, logs)
585 |
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tfops as Z
3 | import horovod.tensorflow as hvd
4 |
5 | # Optimizers
6 |
7 | '''
8 | Polyak averaging op
9 | '''
10 |
11 |
12 | def polyak(params, beta):
13 | #params = tf.trainable_variables()
14 | ema = tf.train.ExponentialMovingAverage(decay=beta, zero_debias=True)
15 | avg_op = tf.group(ema.apply(params))
16 | # Swapping op
17 | updates = []
18 | for i in range(len(params)):
19 | p = params[i]
20 | avg = ema.average(p)
21 | tmp = 0. + avg * 1.
22 | with tf.control_dependencies([tmp]):
23 | update1 = avg.assign(p)
24 | with tf.control_dependencies([update1]):
25 | update2 = p.assign(tmp)
26 | updates += [update1, update2]
27 | swap_op = tf.group(*updates)
28 | return avg_op, swap_op, ema
29 |
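# swap_op exchanges each parameter with its moving average in place: tmp
# materializes a copy of the current average, the average is overwritten with
# the raw parameter, and the parameter is overwritten with tmp. Running
# swap_op once evaluates under averaged weights; running it again restores
# the raw training weights.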
30 |
31 | def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
32 | updates = []
33 | if type(cost_or_grads) is not list:
34 | gs = tf.gradients(cost_or_grads, params)
35 | else:
36 | gs = cost_or_grads
37 |
38 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs)
39 |
40 | # all-reduce
41 | grads = [Z.allreduce_mean(g) for g in gs]
42 |
43 |     t = tf.Variable(1., name='adam_t', trainable=False)  # step counter; tf.Variable's 2nd positional arg is trainable, not name
44 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \
45 | (1. - tf.pow(hps.beta1, t))
46 | updates.append(t.assign_add(1))
47 |
48 | for w, g in zip(params, grads):
49 |         mom2 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m2', trainable=False)
50 | if hps.beta1 > 0:
51 |             mom1 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m1', trainable=False)
52 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g
53 | updates.append(mom1.assign(mom1_new))
54 | else:
55 | mom1_new = g
56 | m2_new = beta2 * mom2 + (1. - beta2) * tf.square(g)
57 | delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon)
58 | w_new = hps.weight_decay * w - alpha_t * delta_t
59 | updates.append(mom2.assign(m2_new))
60 | updates.append(w.assign(w_new))
61 |
62 | # Polyak averaging
63 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
64 | train_op = tf.group(polyak_avg_op, *updates)
65 | return train_op, polyak_swap_op, ema
66 |
67 |
68 | '''
69 | Adam optimizer
70 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum).
71 | (It doesn't seem to work yet, though.)
72 | '''
73 |
74 |
75 | def adam2(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
76 | updates = []
77 | if type(cost_or_grads) is not list:
78 | gs = tf.gradients(cost_or_grads, params)
79 | else:
80 | gs = cost_or_grads
81 |
82 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs)
83 |
84 | # all-reduce
85 | grads1 = [Z.allreduce_mean(g) for g in gs]
86 | grads2 = [Z.allreduce_mean(g**2) for g in gs]
87 |
88 |     t = tf.Variable(1., name='adam_t', trainable=False)
89 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \
90 | (1. - tf.pow(hps.beta1, t))
91 | updates.append(t.assign_add(1))
92 |
93 | for w, g1, g2 in zip(params, grads1, grads2):
94 |         mom2 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m2', trainable=False)
95 | if hps.beta1 > 0:
96 |             mom1 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m1', trainable=False)
97 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g1
98 | updates.append(mom1.assign(mom1_new))
99 | else:
100 | mom1_new = g1
101 | m2_new = beta2 * mom2 + (1. - beta2) * g2
102 | delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon)
103 | w_new = hps.weight_decay * w - alpha_t * delta_t
104 | updates.append(mom2.assign(m2_new))
105 | updates.append(w.assign(w_new))
106 |
107 | # Polyak averaging
108 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
109 | train_op = tf.group(polyak_avg_op, *updates)
110 | return train_op, polyak_swap_op, ema
111 |
112 |
113 | '''
114 | Adam optimizer
115 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum).
116 | It doesn't seem to work though.
117 | '''
118 |
119 |
120 | def adam2_old(params, cost_or_grads, lr=3e-4, mom1=0.9, mom2=0.999, epsilon=1e-8):
121 | updates = []
122 | if type(cost_or_grads) is not list:
123 | gs = tf.gradients(cost_or_grads, params)
124 | else:
125 | gs = cost_or_grads
126 |
127 | # all-reduce
128 | grads1 = [Z.allreduce_mean(g) for g in gs]
129 | grads2 = [Z.allreduce_mean(tf.square(g)) for g in gs]
130 | mom2 = tf.maximum(0., 1. - (hvd.size() * (1 - mom2)))
131 |
132 |     t = tf.Variable(1., name='adam_t', trainable=False)
133 | lr_t = lr * tf.sqrt((1. - tf.pow(mom2, t))) / (1. - tf.pow(mom1, t))
134 | updates.append(t.assign_add(1))
135 |
136 | for p, g1, g2 in zip(params, grads1, grads2):
137 |         mg = tf.Variable(tf.zeros(p.get_shape()), name=p.name.split(':')[0] + '_adam_mg', trainable=False)
138 | if mom1 > 0:
139 |             v = tf.Variable(tf.zeros(p.get_shape()), name=p.name.split(':')[0] + '_adam_v', trainable=False)
140 | v_t = mom1 * v + (1. - mom1) * g1
141 | updates.append(v.assign(v_t))
142 | else:
143 | v_t = g1
144 | mg_t = mom2 * mg + (1. - mom2) * g2
145 | delta_t = v_t / (tf.sqrt(mg_t) + epsilon)
146 | p_t = p - lr_t * delta_t
147 | updates.append(mg.assign(mg_t))
148 | updates.append(p.assign(p_t))
149 | return tf.group(*updates)
150 |
151 |
152 | def adamax(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
153 | updates = []
154 | if type(cost_or_grads) is not list:
155 | gs = tf.gradients(cost_or_grads, params)
156 | else:
157 | gs = cost_or_grads
158 |
159 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs)
160 |
161 | # all-reduce
162 | grads = [Z.allreduce_mean(g) for g in gs]
163 |
164 |     t = tf.Variable(1., name='adam_t', trainable=False)
165 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \
166 | (1. - tf.pow(hps.beta1, t))
167 | updates.append(t.assign_add(1))
168 |
169 | for w, g in zip(params, grads):
170 |         mom2 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m2', trainable=False)
171 | if hps.beta1 > 0:
172 |             mom1 = tf.Variable(tf.zeros(w.get_shape()), name=w.name.split(':')[0] + '_adam_m1', trainable=False)
173 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g
174 | updates.append(mom1.assign(mom1_new))
175 | else:
176 | mom1_new = g
177 | m2_new = tf.maximum(beta2 * mom2, abs(g))
178 | delta_t = mom1_new / (m2_new + epsilon)
179 | w_new = hps.weight_decay * w - alpha_t * delta_t
180 | updates.append(mom2.assign(m2_new))
181 | updates.append(w.assign(w_new))
182 |
183 | # Polyak averaging
184 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
185 | train_op = tf.group(polyak_avg_op, *updates)
186 | return train_op, polyak_swap_op, ema
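# adamax replaces adam's second-moment EMA with an exponentially decayed
# infinity norm, u_t = max(beta2 * u_{t-1}, |g_t|), following the AdaMax
# variant of Kingma & Ba (2015); the denominator uses u_t directly, with no
# square root.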
187 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu==1.8.0
2 | keras==2.2.0
3 | pillow==5.2.0
4 | toposort==1.5
5 | horovod==0.13.8
6 |
--------------------------------------------------------------------------------
/tfops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
3 | from tensorflow.contrib.layers import variance_scaling_initializer
4 | import numpy as np
5 | import horovod.tensorflow as hvd
6 |
7 | # Debugging function
8 | do_print_act_stats = True
9 |
10 |
11 | def print_act_stats(x, _str=""):
12 | if not do_print_act_stats:
13 | return x
14 | if hvd.rank() != 0:
15 | return x
16 |     if len(x.get_shape()) in (1, 2):
17 |         x_mean, x_var = tf.nn.moments(x, [0], keep_dims=True)
20 | if len(x.get_shape()) == 4:
21 | x_mean, x_var = tf.nn.moments(x, [0, 1, 2], keep_dims=True)
22 | stats = [tf.reduce_min(x_mean), tf.reduce_mean(x_mean), tf.reduce_max(x_mean),
23 | tf.reduce_min(tf.sqrt(x_var)), tf.reduce_mean(tf.sqrt(x_var)), tf.reduce_max(tf.sqrt(x_var))]
24 | return tf.Print(x, stats, "["+_str+"] "+x.name)
25 |
26 | # Allreduce methods
27 |
28 |
29 | def allreduce_sum(x):
30 | if hvd.size() == 1:
31 | return x
32 | return hvd.mpi_ops._allreduce(x)
33 |
34 |
35 | def allreduce_mean(x):
36 | x = allreduce_sum(x) / hvd.size()
37 | return x
38 |
39 |
40 | def default_initial_value(shape, std=0.05):
41 | return tf.random_normal(shape, 0., std)
42 |
43 |
44 | def default_initializer(std=0.05):
45 | return tf.random_normal_initializer(0., std)
46 |
47 |
48 | def int_shape(x):
49 | if str(x.get_shape()[0]) != '?':
50 | return list(map(int, x.get_shape()))
51 | return [-1]+list(map(int, x.get_shape()[1:]))
52 |
53 | # Wrapper around tf.get_variable, augmented with 'init' functionality:
54 | # get a variable with data-dependent initialization
55 |
56 |
57 | @add_arg_scope
58 | def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, trainable=True):
59 | w = tf.get_variable(name, shape, dtype, None, trainable=trainable)
60 | if init:
61 | w = w.assign(initial_value)
62 | with tf.control_dependencies([w]):
63 | return w
64 | return w
65 |
66 | # Activation normalization
67 | # Convenience function that does centering+scaling
68 |
69 |
70 | @add_arg_scope
71 | def actnorm(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
72 |     with arg_scope([get_variable_ddi], trainable=trainable):
73 | if not reverse:
74 | x = actnorm_center(name+"_center", x, reverse)
75 | x = actnorm_scale(name+"_scale", x, scale, logdet,
76 | logscale_factor, batch_variance, reverse, init)
77 |         if logdet is not None:
78 | x, logdet = x
79 | else:
80 | x = actnorm_scale(name + "_scale", x, scale, logdet,
81 | logscale_factor, batch_variance, reverse, init)
82 |             if logdet is not None:
83 | x, logdet = x
84 | x = actnorm_center(name+"_center", x, reverse)
85 |     if logdet is not None:
86 | return x, logdet
87 | return x
88 |
89 | # Activation normalization
90 |
91 |
92 | @add_arg_scope
93 | def actnorm_center(name, x, reverse=False):
94 | shape = x.get_shape()
95 | with tf.variable_scope(name):
96 | assert len(shape) == 2 or len(shape) == 4
97 | if len(shape) == 2:
98 | x_mean = tf.reduce_mean(x, [0], keepdims=True)
99 | b = get_variable_ddi(
100 | "b", (1, int_shape(x)[1]), initial_value=-x_mean)
101 | elif len(shape) == 4:
102 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True)
103 | b = get_variable_ddi(
104 | "b", (1, 1, 1, int_shape(x)[3]), initial_value=-x_mean)
105 |
106 | if not reverse:
107 | x += b
108 | else:
109 | x -= b
110 |
111 | return x
112 |
113 | # Activation normalization
114 |
115 |
116 | @add_arg_scope
117 | def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
118 | shape = x.get_shape()
119 | with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable):
120 | assert len(shape) == 2 or len(shape) == 4
121 | if len(shape) == 2:
122 | x_var = tf.reduce_mean(x**2, [0], keepdims=True)
123 | logdet_factor = 1
124 | _shape = (1, int_shape(x)[1])
125 |
126 | elif len(shape) == 4:
127 | x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True)
128 | logdet_factor = int(shape[1])*int(shape[2])
129 | _shape = (1, 1, 1, int_shape(x)[3])
130 |
131 | if batch_variance:
132 | x_var = tf.reduce_mean(x**2, keepdims=True)
133 |
134 | if init and False:
135 | # MPI all-reduce
136 | x_var = allreduce_mean(x_var)
137 | # Somehow this also slows down graph when not initializing
138 | # (it's not optimized away?)
139 |
140 | if True:
141 | logs = get_variable_ddi("logs", _shape, initial_value=tf.log(
142 | scale/(tf.sqrt(x_var)+1e-6))/logscale_factor)*logscale_factor
143 | if not reverse:
144 | x = x * tf.exp(logs)
145 | else:
146 | x = x * tf.exp(-logs)
147 | else:
148 | # Alternative, doesn't seem to do significantly worse or better than the logarithmic version above
149 | s = get_variable_ddi("s", _shape, initial_value=scale /
150 | (tf.sqrt(x_var) + 1e-6) / logscale_factor)*logscale_factor
151 | logs = tf.log(tf.abs(s))
152 | if not reverse:
153 | x *= s
154 | else:
155 | x /= s
156 |
157 |         if logdet is not None:
158 | dlogdet = tf.reduce_sum(logs) * logdet_factor
159 | if reverse:
160 | dlogdet *= -1
161 | return x, logdet + dlogdet
162 |
163 | return x
164 |
165 | # Linear layer with layer norm
166 |
167 |
168 | @add_arg_scope
169 | def linear(name, x, width, do_weightnorm=True, do_actnorm=True, initializer=None, scale=1.):
170 | initializer = initializer or default_initializer()
171 | with tf.variable_scope(name):
172 | n_in = int(x.get_shape()[1])
173 | w = tf.get_variable("W", [n_in, width],
174 | tf.float32, initializer=initializer)
175 | if do_weightnorm:
176 | w = tf.nn.l2_normalize(w, [0])
177 | x = tf.matmul(x, w)
178 | x += tf.get_variable("b", [1, width],
179 | initializer=tf.zeros_initializer())
180 | if do_actnorm:
181 | x = actnorm("actnorm", x, scale)
182 | return x
183 |
184 | # Linear layer with zero init
185 |
186 |
187 | @add_arg_scope
188 | def linear_zeros(name, x, width, logscale_factor=3):
189 | with tf.variable_scope(name):
190 | n_in = int(x.get_shape()[1])
191 | w = tf.get_variable("W", [n_in, width], tf.float32,
192 | initializer=tf.zeros_initializer())
193 | x = tf.matmul(x, w)
194 | x += tf.get_variable("b", [1, width],
195 | initializer=tf.zeros_initializer())
196 | x *= tf.exp(tf.get_variable("logs",
197 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
198 | return x
199 |
200 | # Slow way to add edge padding
201 |
202 |
203 | def add_edge_padding(x, filter_size):
204 | assert filter_size[0] % 2 == 1
205 | if filter_size[0] == 1 and filter_size[1] == 1:
206 | return x
207 | a = (filter_size[0] - 1) // 2 # vertical padding size
208 | b = (filter_size[1] - 1) // 2 # horizontal padding size
209 | if True:
210 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
211 | name = "_".join([str(dim) for dim in [a, b, *int_shape(x)[1:3]]])
212 | pads = tf.get_collection(name)
213 | if not pads:
214 | if hvd.rank() == 0:
215 | print("Creating pad", name)
216 | pad = np.zeros([1] + int_shape(x)[1:3] + [1], dtype='float32')
217 | pad[:, :a, :, 0] = 1.
218 | pad[:, -a:, :, 0] = 1.
219 | pad[:, :, :b, 0] = 1.
220 | pad[:, :, -b:, 0] = 1.
221 | pad = tf.convert_to_tensor(pad)
222 | tf.add_to_collection(name, pad)
223 | else:
224 | pad = pads[0]
225 | pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1])
226 | x = tf.concat([x, pad], axis=3)
227 | else:
228 | pad = tf.pad(tf.zeros_like(x[:, :, :, :1]) - 1,
229 | [[0, 0], [a, a], [b, b], [0, 0]]) + 1
230 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
231 | x = tf.concat([x, pad], axis=3)
232 | return x
233 |
234 |
235 | @add_arg_scope
236 | def conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, edge_bias=True):
237 | with tf.variable_scope(name):
238 | if edge_bias and pad == "SAME":
239 | x = add_edge_padding(x, filter_size)
240 | pad = 'VALID'
241 |
242 | n_in = int(x.get_shape()[3])
243 |
244 | stride_shape = [1] + stride + [1]
245 | filter_shape = filter_size + [n_in, width]
246 | w = tf.get_variable("W", filter_shape, tf.float32,
247 | initializer=default_initializer())
248 | if do_weightnorm:
249 | w = tf.nn.l2_normalize(w, [0, 1, 2])
250 | if skip == 1:
251 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
252 | else:
253 | assert stride[0] == 1 and stride[1] == 1
254 | x = tf.nn.atrous_conv2d(x, w, skip, pad)
255 | if do_actnorm:
256 | x = actnorm("actnorm", x)
257 | else:
258 | x += tf.get_variable("b", [1, 1, 1, width],
259 | initializer=tf.zeros_initializer())
260 |
261 |         if context1d is not None:
262 | x += tf.reshape(linear("context", context1d,
263 | width), [-1, 1, 1, width])
264 | return x
265 |
266 |
267 | @add_arg_scope
268 | def separable_conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], padding="SAME", do_actnorm=True, std=0.05):
269 | n_in = int(x.get_shape()[3])
270 | with tf.variable_scope(name):
271 | assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1
272 | strides = [1] + stride + [1]
273 | w1_shape = filter_size + [n_in, 1]
274 | w1_init = np.zeros(w1_shape, dtype='float32')
275 | w1_init[(filter_size[0]-1)//2, (filter_size[1]-1)//2, :,
276 | :] = 1. # initialize depthwise conv as identity
277 | w1 = tf.get_variable("W1", dtype=tf.float32, initializer=w1_init)
278 | w2_shape = [1, 1, n_in, width]
279 | w2 = tf.get_variable("W2", w2_shape, tf.float32,
280 | initializer=default_initializer(std))
281 | x = tf.nn.separable_conv2d(
282 | x, w1, w2, strides, padding, data_format='NHWC')
283 | if do_actnorm:
284 | x = actnorm("actnorm", x)
285 | else:
286 | x += tf.get_variable("b", [1, 1, 1, width],
287 |                                  initializer=tf.zeros_initializer())
288 |
289 | return x
290 |
291 |
292 | @add_arg_scope
293 | def conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", logscale_factor=3, skip=1, edge_bias=True):
294 | with tf.variable_scope(name):
295 | if edge_bias and pad == "SAME":
296 | x = add_edge_padding(x, filter_size)
297 | pad = 'VALID'
298 |
299 | n_in = int(x.get_shape()[3])
300 | stride_shape = [1] + stride + [1]
301 | filter_shape = filter_size + [n_in, width]
302 | w = tf.get_variable("W", filter_shape, tf.float32,
303 | initializer=tf.zeros_initializer())
304 | if skip == 1:
305 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
306 | else:
307 | assert stride[0] == 1 and stride[1] == 1
308 | x = tf.nn.atrous_conv2d(x, w, skip, pad)
309 | x += tf.get_variable("b", [1, 1, 1, width],
310 | initializer=tf.zeros_initializer())
311 | x *= tf.exp(tf.get_variable("logs",
312 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
313 | return x
314 |
315 |
316 | # 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code
317 | def upsample2d_nearest_neighbour(x):
318 | shape = x.get_shape()
319 | n_batch = int(shape[0])
320 | height = int(shape[1])
321 | width = int(shape[2])
322 | n_channels = int(shape[3])
323 | x = tf.reshape(x, (n_batch, height, 1, width, 1, n_channels))
324 |     x = tf.concat([x, x], 2)
325 |     x = tf.concat([x, x], 4)
326 | x = tf.reshape(x, (n_batch, height*2, width*2, n_channels))
327 | return x
328 |
329 |
330 | def upsample(x, factor=2):
331 | shape = x.get_shape()
332 | height = int(shape[1])
333 | width = int(shape[2])
334 | x = tf.image.resize_nearest_neighbor(x, [height * factor, width * factor])
335 | return x
336 |
337 |
338 | def squeeze2d(x, factor=2):
339 | assert factor >= 1
340 | if factor == 1:
341 | return x
342 | shape = x.get_shape()
343 | height = int(shape[1])
344 | width = int(shape[2])
345 | n_channels = int(shape[3])
346 | assert height % factor == 0 and width % factor == 0
347 | x = tf.reshape(x, [-1, height//factor, factor,
348 | width//factor, factor, n_channels])
349 | x = tf.transpose(x, [0, 1, 3, 5, 2, 4])
350 | x = tf.reshape(x, [-1, height//factor, width //
351 | factor, n_channels*factor*factor])
352 | return x
353 |
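# Worked example (factor=2): squeeze2d trades space for channels,
# [N, H, W, C] -> [N, H/2, W/2, 4C], by folding each 2x2 spatial block into
# the channel axis (e.g. 16x16x12 -> 8x8x48); unsqueeze2d below is its exact
# inverse.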
354 |
355 | def unsqueeze2d(x, factor=2):
356 | assert factor >= 1
357 | if factor == 1:
358 | return x
359 | shape = x.get_shape()
360 | height = int(shape[1])
361 | width = int(shape[2])
362 | n_channels = int(shape[3])
363 |     assert n_channels >= factor ** 2 and n_channels % factor ** 2 == 0
364 | x = tf.reshape(
365 | x, (-1, height, width, int(n_channels/factor**2), factor, factor))
366 | x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
367 | x = tf.reshape(x, (-1, int(height*factor),
368 | int(width*factor), int(n_channels/factor**2)))
369 | return x
370 |
371 | # Reverse features across channel dimension
372 |
373 |
374 | def reverse_features(name, h, reverse=False):
375 |     return h[:, :, :, ::-1]  # reversing channel order is self-inverse, so 'reverse' is ignored
376 |
377 | # Shuffle across the channel dimension
378 |
379 |
380 | def shuffle_features(name, h, indices=None, return_indices=False, reverse=False):
381 | with tf.variable_scope(name):
382 |
383 | rng = np.random.RandomState(
384 | (abs(hash(tf.get_variable_scope().name))) % 10000000)
385 |
386 |         if indices is None:
387 | # Create numpy and tensorflow variables with indices
388 | n_channels = int(h.get_shape()[-1])
389 | indices = list(range(n_channels))
390 | rng.shuffle(indices)
391 | # Reverse it
392 | indices_inverse = [0]*n_channels
393 | for i in range(n_channels):
394 | indices_inverse[indices[i]] = i
395 |
396 | tf_indices = tf.get_variable("indices", dtype=tf.int32, initializer=np.asarray(
397 | indices, dtype='int32'), trainable=False)
398 | tf_indices_reverse = tf.get_variable("indices_inverse", dtype=tf.int32, initializer=np.asarray(
399 | indices_inverse, dtype='int32'), trainable=False)
400 |
401 | _indices = tf_indices
402 | if reverse:
403 | _indices = tf_indices_reverse
404 |
405 | if len(h.get_shape()) == 2:
406 | # Slice
407 | h = tf.transpose(h)
408 | h = tf.gather(h, _indices)
409 | h = tf.transpose(h)
410 | elif len(h.get_shape()) == 4:
411 | # Slice
412 | h = tf.transpose(h, [3, 1, 2, 0])
413 | h = tf.gather(h, _indices)
414 | h = tf.transpose(h, [3, 1, 2, 0])
415 | if return_indices:
416 | return h, indices
417 | return h
418 |
419 |
420 | def embedding(name, y, n_y, width):
421 | with tf.variable_scope(name):
422 | params = tf.get_variable(
423 | "embedding", [n_y, width], initializer=default_initializer())
424 | embeddings = tf.gather(params, y)
425 | return embeddings
426 |
427 | # Random variables
428 |
429 |
430 | def flatten_sum(logps):
431 | if len(logps.get_shape()) == 2:
432 | return tf.reduce_sum(logps, [1])
433 | elif len(logps.get_shape()) == 4:
434 | return tf.reduce_sum(logps, [1, 2, 3])
435 | else:
436 |         raise Exception("logps must be rank 2 or 4, got shape %s" % logps.get_shape())
437 |
438 |
439 | def standard_gaussian(shape):
440 | return gaussian_diag(tf.zeros(shape), tf.zeros(shape))
441 |
442 |
443 | def gaussian_diag(mean, logsd):
444 | class o(object):
445 | pass
446 | o.mean = mean
447 | o.logsd = logsd
448 | o.eps = tf.random_normal(tf.shape(mean))
449 | o.sample = mean + tf.exp(logsd) * o.eps
450 | o.sample2 = lambda eps: mean + tf.exp(logsd) * eps
451 | o.logps = lambda x: -0.5 * \
452 | (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd))
453 | o.logp = lambda x: flatten_sum(o.logps(x))
454 | o.get_eps = lambda x: (x - mean) / tf.exp(logsd)
455 | return o
456 |
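# gaussian_diag implements the reparameterization sample = mean +
# exp(logsd) * eps with eps ~ N(0, I); get_eps() inverts it, which is what
# lets the model reconstruct an input exactly from its stored eps.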
457 |
458 | # def discretized_logistic_old(mean, logscale, binsize=1 / 256.0, sample=None):
459 | # scale = tf.exp(logscale)
460 | # sample = (tf.floor(sample / binsize) * binsize - mean) / scale
461 | # logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7)
462 | # return tf.reduce_sum(logp, [1, 2, 3])
463 |
464 | def discretized_logistic(mean, logscale, binsize=1. / 256):
465 | class o(object):
466 | pass
467 | o.mean = mean
468 | o.logscale = logscale
469 | scale = tf.exp(logscale)
470 |
471 | def logps(x):
472 | x = (x - mean) / scale
473 | return tf.log(tf.sigmoid(x + binsize / scale) - tf.sigmoid(x) + 1e-7)
474 | o.logps = logps
475 | o.logp = lambda x: flatten_sum(logps(x))
476 | return o
477 |
478 |
479 | def _symmetric_matrix_square_root(mat, eps=1e-10):
480 | """Compute square root of a symmetric matrix.
481 | Note that this is different from an elementwise square root. We want to
482 | compute M' where M' = sqrt(mat) such that M' * M' = mat.
483 | Also note that this method **only** works for symmetric matrices.
484 | Args:
485 | mat: Matrix to take the square root of.
486 | eps: Small epsilon such that any element less than eps will not be square
487 | rooted to guard against numerical instability.
488 | Returns:
489 | Matrix square root of mat.
490 | """
491 | # Unlike numpy, tensorflow's return order is (s, u, v)
492 | s, u, v = tf.svd(mat)
493 |     # sqrt is unstable around 0; where s < eps, keep s itself (~0) instead of sqrt(s)
494 | si = tf.where(tf.less(s, eps), s, tf.sqrt(s))
495 | # Note that the v returned by Tensorflow is v = V
496 | # (when referencing the equation A = U S V^T)
497 | # This is unlike Numpy which returns v = V^T
498 | return tf.matmul(
499 | tf.matmul(u, tf.diag(si)), v, transpose_b=True)
500 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Modified Horovod MNIST example
4 |
5 | import os
6 | import sys
7 | import time
8 |
9 | import horovod.tensorflow as hvd
10 | import numpy as np
11 | import tensorflow as tf
12 | import graphics
13 | from utils import ResultLogger
14 |
15 | learn = tf.contrib.learn
16 |
17 | # Suppress verbose warnings
18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
19 |
20 |
21 | def _print(*args, **kwargs):
22 | if hvd.rank() == 0:
23 | print(*args, **kwargs)
24 |
25 |
26 | def init_visualizations(hps, model, logdir):
27 |
28 | def sample_batch(y, eps):
29 | n_batch = hps.local_batch_train
30 | xs = []
31 | for i in range(int(np.ceil(len(eps) / n_batch))):
32 | xs.append(model.sample(
33 | y[i*n_batch:i*n_batch + n_batch], eps[i*n_batch:i*n_batch + n_batch]))
34 | return np.concatenate(xs)
35 |
36 | def draw_samples(epoch):
37 | if hvd.rank() != 0:
38 | return
39 |
40 | rows = 10 if hps.image_size <= 64 else 4
41 | cols = rows
42 | n_batch = rows*cols
43 | y = np.asarray([_y % hps.n_y for _y in (
44 | list(range(cols)) * rows)], dtype='int32')
45 |
46 | # temperatures = [0., .25, .5, .626, .75, .875, 1.] #previously
47 | temperatures = [0., .25, .5, .6, .7, .8, .9, 1.]
48 |
49 |         x_samples = []
50 |         for temp in temperatures:
51 |             x_samples.append(sample_batch(y, [temp] * n_batch))
59 |
60 | for i in range(len(x_samples)):
61 | x_sample = np.reshape(
62 | x_samples[i], (n_batch, hps.image_size, hps.image_size, 3))
63 | graphics.save_raster(x_sample, logdir +
64 | 'epoch_{}_sample_{}.png'.format(epoch, i))
65 |
66 | return draw_samples
67 |
68 | # ===
69 | # Code for getting data
70 | # ===
71 | def get_data(hps, sess):
72 | if hps.image_size == -1:
73 | hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64,
74 | 'imagenet': 256, 'celeba': 256, 'lsun_realnvp': 64, 'lsun': 256}[hps.problem]
75 | if hps.n_test == -1:
76 | hps.n_test = {'mnist': 10000, 'cifar10': 10000, 'imagenet-oord': 50000, 'imagenet': 50000,
77 | 'celeba': 3000, 'lsun_realnvp': 300*hvd.size(), 'lsun': 300*hvd.size()}[hps.problem]
78 | hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000,
79 | 'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1, 'lsun': 1}[hps.problem]
80 | if hps.data_dir == "":
81 | hps.data_dir = {'mnist': None, 'cifar10': None, 'imagenet-oord': '/mnt/host/imagenet-oord-tfr', 'imagenet': '/mnt/host/imagenet-tfr',
82 | 'celeba': '/mnt/host/celeba-reshard-tfr', 'lsun_realnvp': '/mnt/host/lsun_realnvp', 'lsun': '/mnt/host/lsun'}[hps.problem]
83 |
84 | if hps.problem == 'lsun_realnvp':
85 | hps.rnd_crop = True
86 | else:
87 | hps.rnd_crop = False
88 |
89 | if hps.category:
90 | hps.data_dir += ('/%s' % hps.category)
91 |
92 | # Use anchor_size to rescale batch size based on image_size
93 | s = hps.anchor_size
94 | hps.local_batch_train = hps.n_batch_train * \
95 | s * s // (hps.image_size * hps.image_size)
96 | hps.local_batch_test = {64: 50, 32: 25, 16: 10, 8: 5, 4: 2, 2: 2, 1: 1}[
97 | hps.local_batch_train] # round down to closest divisor of 50
98 | hps.local_batch_init = hps.n_batch_init * \
99 | s * s // (hps.image_size * hps.image_size)
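    # e.g. with anchor_size=32, n_batch_train=64 and image_size=64, the local
    # train batch is 64 * 32 * 32 // (64 * 64) = 16 per rank, keeping pixel
    # throughput roughly constant across resolutions (the test batch is then
    # looked up as 10).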
100 |
101 | print("Rank {} Batch sizes Train {} Test {} Init {}".format(
102 | hvd.rank(), hps.local_batch_train, hps.local_batch_test, hps.local_batch_init))
103 |
104 | if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']:
105 | hps.direct_iterator = True
106 | import data_loaders.get_data as v
107 | train_iterator, test_iterator, data_init = \
108 | v.get_data(sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap, hps.local_batch_train,
109 | hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop)
110 |
111 | elif hps.problem in ['mnist', 'cifar10']:
112 | hps.direct_iterator = False
113 | import data_loaders.get_mnist_cifar as v
114 | train_iterator, test_iterator, data_init = \
115 | v.get_data(hps.problem, hvd.size(), hvd.rank(), hps.dal, hps.local_batch_train,
116 | hps.local_batch_test, hps.local_batch_init, hps.image_size)
117 |
118 | else:
119 |         raise Exception("unknown problem: %s" % hps.problem)
120 |
121 | return train_iterator, test_iterator, data_init
122 |
123 |
124 | def process_results(results):
125 | stats = ['loss', 'bits_x', 'bits_y', 'pred_loss']
126 | assert len(stats) == results.shape[0]
127 | res_dict = {}
128 | for i in range(len(stats)):
129 | res_dict[stats[i]] = "{:.4f}".format(results[i])
130 | return res_dict
131 |
132 |
133 | def main(hps):
134 |
135 | # Initialize Horovod.
136 | hvd.init()
137 |
138 | # Create tensorflow session
139 | sess = tensorflow_session()
140 |
141 |     # Set random seeds (offset per rank; see the note below)
142 | tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed)
143 | np.random.seed(hvd.rank() + hvd.size() * hps.seed)
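    # Offsetting the seed by hvd.rank() gives each worker a distinct but
    # reproducible random stream, while the hvd.size() * hps.seed term keeps
    # the per-rank streams from colliding across different --seed values.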
144 |
145 | # Get data and set train_its and valid_its
146 | train_iterator, test_iterator, data_init = get_data(hps, sess)
147 | hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)
148 |
149 | # Create log dir
150 |     logdir = os.path.abspath(hps.logdir) + "/"
151 |     # makedirs with exist_ok avoids a race when several Horovod ranks start up
152 |     os.makedirs(logdir, exist_ok=True)
153 |
154 | # Create model
155 | import model
156 | model = model.model(sess, hps, train_iterator, test_iterator, data_init)
157 |
158 | # Initialize visualization functions
159 | visualise = init_visualizations(hps, model, logdir)
160 |
161 | if not hps.inference:
162 | # Perform training
163 | train(sess, model, hps, logdir, visualise)
164 | else:
165 | infer(sess, model, hps, test_iterator)
166 |
167 |
168 | def infer(sess, model, hps, iterator):
169 |     # Example of using the model in inference mode. Load a saved model via hps.restore_path.
170 |     # x and y can also be provided from files instead of the dataset iterator.
171 |     # If the model is unconditional, always pass y = np.zeros([bs], dtype=np.int32).
172 | if hps.direct_iterator:
173 | iterator = iterator.get_next()
174 |
175 | xs = []
176 | zs = []
177 | for it in range(hps.full_test_its):
178 | if hps.direct_iterator:
179 |             # Replace with x, y, attr when reading CelebA attributes; get_data must be modified to match.
180 | x, y = sess.run(iterator)
181 | else:
182 | x, y = iterator()
183 |
184 | z = model.encode(x, y)
185 | x = model.decode(y, z)
186 | xs.append(x)
187 | zs.append(z)
188 |
189 | x = np.concatenate(xs, axis=0)
190 | z = np.concatenate(zs, axis=0)
191 | np.save('logs/x.npy', x)
192 | np.save('logs/z.npy', z)
193 | return zs
194 |
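# Illustrative inference invocation (the flag values are examples only):
#   python train.py --inference --restore_path logs/model_best_loss.ckpt
# The stacked reconstructions and latents are written to logs/x.npy and
# logs/z.npy; the 'logs/' prefix here is hardcoded, independent of --logdir.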
195 |
196 | def train(sess, model, hps, logdir, visualise):
197 | _print(hps)
198 | _print('Starting training. Logging to', logdir)
199 | _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')
200 |
201 | # Train
202 | sess.graph.finalize()
203 | n_processed = 0
204 | n_images = 0
205 | train_time = 0.0
206 | test_loss_best = 999999
207 |
208 | if hvd.rank() == 0:
209 | train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
210 | test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)
211 |
212 | tcurr = time.time()
213 | for epoch in range(1, hps.epochs):
214 |
215 | t = time.time()
216 |
217 | train_results = []
218 | for it in range(hps.train_its):
219 |
220 | # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
221 | lr = hps.lr * min(1., n_processed /
222 | (hps.n_train * hps.epochs_warmup))
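            # Worked example with the defaults lr = 0.001, n_train = 50000 and
            # epochs_warmup = 10: the rate ramps linearly from 0 and reaches
            # its base value once n_processed = 500000, i.e. after 10
            # anchor-sized epochs.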
223 |
224 | # Run a training step synchronously.
225 | _t = time.time()
226 | train_results += [model.train(lr)]
227 | if hps.verbose and hvd.rank() == 0:
228 | _print(n_processed, time.time()-_t, train_results[-1])
229 | sys.stdout.flush()
230 |
231 | # Images seen wrt anchor resolution
232 | n_processed += hvd.size() * hps.n_batch_train
233 | # Actual images seen at current resolution
234 | n_images += hvd.size() * hps.local_batch_train
235 |
236 | train_results = np.mean(np.asarray(train_results), axis=0)
237 |
238 | dtrain = time.time() - t
239 | ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
240 | train_time += dtrain
241 |
242 | if hvd.rank() == 0:
243 | train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
244 | train_time), **process_results(train_results))
245 |
246 | if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
247 | test_results = []
248 | msg = ''
249 |
250 | t = time.time()
251 | # model.polyak_swap()
252 |
253 | if epoch % hps.epochs_full_valid == 0:
254 | # Full validation run
255 | for it in range(hps.full_test_its):
256 | test_results += [model.test()]
257 | test_results = np.mean(np.asarray(test_results), axis=0)
258 |
259 | if hvd.rank() == 0:
260 | test_logger.log(epoch=epoch, n_processed=n_processed,
261 | n_images=n_images, **process_results(test_results))
262 |
263 | # Save checkpoint
264 | if test_results[0] < test_loss_best:
265 | test_loss_best = test_results[0]
266 | model.save(logdir+"model_best_loss.ckpt")
267 | msg += ' *'
268 |
269 | dtest = time.time() - t
270 |
271 | # Sample
272 | t = time.time()
273 | if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
274 | visualise(epoch)
275 | dsample = time.time() - t
276 |
277 | if hvd.rank() == 0:
278 | dcurr = time.time() - tcurr
279 | tcurr = time.time()
280 | _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
281 | ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)
282 |
283 | # model.polyak_swap()
284 |
285 | if hvd.rank() == 0:
286 | _print("Finished!")
287 |
288 | # Get number of training and validation iterations
289 | def get_its(hps):
290 |     # Iteration counts are based on the anchor batch size, so at larger image sizes (smaller local batch) fewer actual examples are seen per epoch.
291 | train_its = int(np.ceil(hps.n_train / (hps.n_batch_train * hvd.size())))
292 | test_its = int(np.ceil(hps.n_test / (hps.n_batch_train * hvd.size())))
293 | train_epoch = train_its * hps.n_batch_train * hvd.size()
294 |
295 | # Do a full validation run
296 | if hvd.rank() == 0:
297 | print(hps.n_test, hps.local_batch_test, hvd.size())
298 | assert hps.n_test % (hps.local_batch_test * hvd.size()) == 0
299 | full_test_its = hps.n_test // (hps.local_batch_test * hvd.size())
300 |
301 | if hvd.rank() == 0:
302 | print("Train epoch size: " + str(train_epoch))
303 | return train_its, test_its, full_test_its
304 |
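# Worked example: n_train = 50000 and n_batch_train = 64 on a single worker
# give train_its = ceil(50000 / 64) = 782 and train_epoch = 782 * 64 = 50048
# images per reported epoch (slightly more than n_train because of the ceil).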
305 |
306 | '''
307 | Create tensorflow session with horovod
308 | '''
309 | def tensorflow_session():
310 | # Init session and params
311 | config = tf.ConfigProto()
312 | config.gpu_options.allow_growth = True
313 | # Pin GPU to local rank (one GPU per process)
314 | config.gpu_options.visible_device_list = str(hvd.local_rank())
315 | sess = tf.Session(config=config)
316 | return sess
317 |
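# Each process therefore sees exactly one GPU: a launch such as
# `mpiexec -n 8 python train.py ...` (an illustrative command, not the only
# option) makes Horovod pin every rank to the GPU matching its local rank.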
318 |
319 | if __name__ == "__main__":
320 |
321 |     # This lets Ctrl-C exit cleanly without triggering errors
322 | import signal
323 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))
324 |
325 | import argparse
326 | parser = argparse.ArgumentParser()
327 | parser.add_argument("--verbose", action='store_true', help="Verbose mode")
328 | parser.add_argument("--restore_path", type=str, default='',
329 | help="Location of checkpoint to restore")
330 | parser.add_argument("--inference", action="store_true",
331 | help="Use in inference mode")
332 | parser.add_argument("--logdir", type=str,
333 | default='./logs', help="Location to save logs")
334 |
335 | # Dataset hyperparams:
336 | parser.add_argument("--problem", type=str, default='cifar10',
337 |                         help="Problem (mnist/cifar10/imagenet-oord/imagenet/celeba/lsun_realnvp/lsun)")
338 | parser.add_argument("--category", type=str,
339 | default='', help="LSUN category")
340 | parser.add_argument("--data_dir", type=str, default='',
341 | help="Location of data")
342 | parser.add_argument("--dal", type=int, default=1,
343 | help="Data augmentation level: 0=None, 1=Standard, 2=Extra")
344 |
345 | # New dataloader params
346 | parser.add_argument("--fmap", type=int, default=1,
347 | help="# Threads for parallel file reading")
348 | parser.add_argument("--pmap", type=int, default=16,
349 | help="# Threads for parallel map")
350 |
351 | # Optimization hyperparams:
352 | parser.add_argument("--n_train", type=int,
353 | default=50000, help="Train epoch size")
354 |     parser.add_argument("--n_test", type=int, default=-1,
355 |                         help="Valid epoch size")
356 | parser.add_argument("--n_batch_train", type=int,
357 | default=64, help="Minibatch size")
358 | parser.add_argument("--n_batch_test", type=int,
359 | default=50, help="Minibatch size")
360 | parser.add_argument("--n_batch_init", type=int, default=256,
361 | help="Minibatch size for data-dependent init")
362 | parser.add_argument("--optimizer", type=str,
363 | default="adamax", help="adam or adamax")
364 | parser.add_argument("--lr", type=float, default=0.001,
365 | help="Base learning rate")
366 | parser.add_argument("--beta1", type=float, default=.9, help="Adam beta1")
367 | parser.add_argument("--polyak_epochs", type=float, default=1,
368 |                         help="Number of averaging epochs for Polyak averaging and beta2")
369 | parser.add_argument("--weight_decay", type=float, default=1.,
370 |                         help="Weight decay (the default of 1.0 switches it off)")
371 | parser.add_argument("--epochs", type=int, default=1000000,
372 | help="Total number of training epochs")
373 | parser.add_argument("--epochs_warmup", type=int,
374 | default=10, help="Warmup epochs")
375 | parser.add_argument("--epochs_full_valid", type=int,
376 | default=50, help="Epochs between valid")
377 | parser.add_argument("--gradient_checkpointing", type=int,
378 | default=1, help="Use memory saving gradients")
379 |
380 | # Model hyperparams:
381 | parser.add_argument("--image_size", type=int,
382 | default=-1, help="Image size")
383 | parser.add_argument("--anchor_size", type=int, default=32,
384 | help="Anchor size for deciding batch size")
385 | parser.add_argument("--width", type=int, default=512,
386 | help="Width of hidden layers")
387 | parser.add_argument("--depth", type=int, default=32,
388 | help="Depth of network")
389 | parser.add_argument("--weight_y", type=float, default=0.00,
390 | help="Weight of log p(y|x) in weighted loss")
391 | parser.add_argument("--n_bits_x", type=int, default=8,
392 | help="Number of bits of x")
393 | parser.add_argument("--n_levels", type=int, default=3,
394 | help="Number of levels")
395 |
396 | # Synthesis/Sampling hyperparameters:
397 | parser.add_argument("--n_sample", type=int, default=1,
398 |                         help="Minibatch size for sampling")
399 | parser.add_argument("--epochs_full_sample", type=int,
400 | default=50, help="Epochs between full scale sample")
401 |
402 | # Ablation
403 | parser.add_argument("--learntop", action="store_true",
404 | help="Learn spatial prior")
405 | parser.add_argument("--ycond", action="store_true",
406 | help="Use y conditioning")
407 | parser.add_argument("--seed", type=int, default=0, help="Random seed")
408 | parser.add_argument("--flow_permutation", type=int, default=2,
409 | help="Type of flow. 0=reverse (realnvp), 1=shuffle, 2=invconv (ours)")
410 | parser.add_argument("--flow_coupling", type=int, default=0,
411 | help="Coupling type: 0=additive, 1=affine")
412 |
413 |     hps = parser.parse_args()  # argparse rejects unknown flags, so typos raise an error
414 | main(hps)
415 |
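# Illustrative single-node run; all flags are defined above, and the specific
# values are examples rather than recommended settings:
#   mpiexec -n 4 python train.py --problem cifar10 --image_size 32 \
#       --flow_permutation 2 --flow_coupling 0 --learntop --lr 0.001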
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | class ResultLogger(object):
5 | def __init__(self, path, *args, **kwargs):
6 | self.f_log = open(path, 'w')
7 | self.f_log.write(json.dumps(kwargs) + '\n')
8 |
9 | def log(self, **kwargs):
10 | self.f_log.write(json.dumps(kwargs) + '\n')
11 | self.f_log.flush()
12 |
13 | def close(self):
14 | self.f_log.close()
15 |
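# Minimal usage sketch (hypothetical path): each log() call appends one JSON
# object per line, so the file can be replayed with json.loads line by line:
#
#   logger = ResultLogger('/tmp/train.txt', lr=0.001)
#   logger.log(epoch=1, loss='2.3146')
#   logger.close()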
--------------------------------------------------------------------------------