├── .DS_Store
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── FAQ.md
├── LICENSE
├── README.md
├── c2pa
│   ├── README.md
│   ├── c2pa_watermark_example.py
│   ├── example.jpg
│   ├── example_wm.jpg
│   ├── example_wm_signed.jpg
│   ├── keys
│   │   ├── README.md
│   │   ├── es256_certs.pem
│   │   └── es256_private.key
│   └── manifest.json
├── images
│   ├── ghost.png
│   ├── ghost_P.png
│   ├── ghost_Q.png
│   ├── ripley.jpg
│   ├── ripley_P.png
│   ├── ripley_Q.png
│   ├── ufo_240.jpg
│   ├── ufo_240_P.png
│   └── ufo_240_Q.png
├── js
│   ├── README.md
│   ├── deps
│   │   ├── bch_ecc.min.js
│   │   ├── ort-wasm-simd-threaded.jsep.mjs
│   │   ├── ort-wasm-simd-threaded.jsep.wasm
│   │   ├── ort-wasm-simd-threaded.mjs
│   │   ├── ort-wasm-simd-threaded.wasm
│   │   ├── ort.webgpu.min.js
│   │   ├── ort.webgpu.min.js.map
│   │   ├── ort.webgpu.min.mjs
│   │   └── ort.webgpu.min.mjs.map
│   ├── index.html
│   ├── tm_datalayer.js
│   └── tm_watermark.js
├── python
│   ├── CONFIG.md
│   ├── LICENSE
│   ├── README.md
│   ├── pyproject.toml
│   ├── requirements.txt
│   ├── setup.py
│   ├── test.py
│   └── trustmark
│       ├── KBNet
│       │   ├── README.md
│       │   ├── arch_util.py
│       │   ├── kb_utils.py
│       │   ├── kbnet.py
│       │   ├── kbnet_l_arch.py
│       │   └── kbnet_s_arch.py
│       ├── __init__.py
│       ├── bchecc.py
│       ├── datalayer.py
│       ├── denoise.py
│       ├── model.py
│       ├── models
│       │   └── README.md
│       ├── trustmark.py
│       └── unet.py
└── rust
    ├── .cargo
    │   └── config.toml
    ├── .gitignore
    ├── Cargo.lock
    ├── Cargo.toml
    ├── LICENSE
    ├── README.md
    ├── benches
    │   ├── encode.rs
    │   ├── encode.sh
    │   ├── load.rs
    │   └── load.sh
    ├── crates
    │   ├── trustmark-cli
    │   │   ├── Cargo.toml
    │   │   ├── README.md
    │   │   └── src
    │   │       └── main.rs
    │   └── xtask
    │       ├── Cargo.toml
    │       └── src
    │           └── main.rs
    ├── models
    │   └── .gitkeep
    └── src
        ├── bits.rs
        ├── bits
        │   └── bch.rs
        ├── image_processing.rs
        ├── lib.rs
        └── model.rs

--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/.DS_Store

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Adobe Code of Conduct
2 | 
3 | ## Our Pledge
4 | 
5 | We as members, contributors, and leaders pledge to make participation in our project and community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
6 | 
7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
8 | 9 | ## Our Standards 10 | 11 | Examples of behavior that contribute to a positive environment for our project and community include: 12 | 13 | * Demonstrating empathy and kindness toward other people 14 | * Being respectful of differing opinions, viewpoints, and experiences 15 | * Giving and gracefully accepting constructive feedback 16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 17 | * Focusing on what is best, not just for us as individuals but for the overall community 18 | 19 | Examples of unacceptable behavior include: 20 | 21 | * The use of sexualized language or imagery, and sexual attention or advances of any kind 22 | * Trolling, insulting or derogatory comments, and personal or political attacks 23 | * Public or private harassment 24 | * Publishing others’ private information, such as a physical or email address, without their explicit permission 25 | * Other conduct which could reasonably be considered inappropriate in a professional setting 26 | 27 | ## Our Responsibilities 28 | 29 | Project maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any instances of unacceptable behavior. 30 | 31 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for behaviors that they deem inappropriate, threatening, offensive, or harmful. 32 | 33 | ## Scope 34 | 35 | This Code of Conduct applies when an individual is representing the project or its community both within project spaces and in public spaces. Examples of representing a project or community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 36 | 37 | ## Enforcement 38 | 39 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by first contacting the project team. Oversight of Adobe projects is handled by the Adobe Open Source Office, which has final say in any violations and enforcement of this Code of Conduct and can be reached at Grp-opensourceoffice@adobe.com. All complaints will be reviewed and investigated promptly and fairly. 40 | 41 | The project team must respect the privacy and security of the reporter of any incident. 42 | 43 | Project maintainers who do not follow or enforce the Code of Conduct may face temporary or permanent repercussions as determined by other members of the project's leadership or the Adobe Open Source Office. 44 | 45 | ## Enforcement Guidelines 46 | 47 | Project maintainers will follow these Community Impact Guidelines in determining the consequences for any action they deem to be in violation of this Code of Conduct: 48 | 49 | **1. Correction** 50 | 51 | Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 52 | 53 | Consequence: A private, written warning from project maintainers describing the violation and why the behavior was unacceptable. A public apology may be requested from the violator before any further involvement in the project by violator. 54 | 55 | **2. 
Warning** 56 | 57 | Community Impact: A relatively minor violation through a single incident or series of actions. 58 | 59 | Consequence: A written warning from project maintainers that includes stated consequences for continued unacceptable behavior. Violator must refrain from interacting with the people involved for a specified period of time as determined by the project maintainers, including, but not limited to, unsolicited interaction with those enforcing the Code of Conduct through channels such as community spaces and social media. Continued violations may lead to a temporary or permanent ban. 60 | 61 | **3. Temporary Ban** 62 | 63 | Community Impact: A more serious violation of community standards, including sustained unacceptable behavior. 64 | 65 | Consequence: A temporary ban from any interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Failure to comply with the temporary ban may lead to a permanent ban. 66 | 67 | **4. Permanent Ban** 68 | 69 | Community Impact: Demonstrating a consistent pattern of violation of community standards or an egregious violation of community standards, including, but not limited to, sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 70 | 71 | Consequence: A permanent ban from any interaction with the community. 72 | 73 | ## Attribution 74 | 75 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, 76 | available at [https://contributor-covenant.org/version/2/1][version] 77 | 78 | [homepage]: https://contributor-covenant.org 79 | [version]: https://contributor-covenant.org/version/2/1 80 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for choosing to contribute! 4 | 5 | The following are a set of guidelines to follow when contributing to this project. 6 | 7 | ## Code Of Conduct 8 | 9 | This project adheres to the Adobe [code of conduct](../CODE_OF_CONDUCT.md). By participating, 10 | you are expected to uphold this code. Please report unacceptable behavior to 11 | [Grp-opensourceoffice@adobe.com](mailto:Grp-opensourceoffice@adobe.com). 12 | 13 | ## Have A Question? 14 | 15 | Start by filing an issue. The existing committers on this project work to reach 16 | consensus around project direction and issue solutions within issue threads 17 | (when appropriate). 18 | 19 | ## Contributor License Agreement 20 | 21 | All third-party contributions to this project must be accompanied by a signed contributor 22 | license agreement. This gives Adobe permission to redistribute your contributions 23 | as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You 24 | only need to submit an Adobe CLA one time, so if you have submitted one previously, 25 | you are good to go! 26 | 27 | ## Code Reviews 28 | 29 | All submissions should come in the form of pull requests and need to be reviewed 30 | by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/) 31 | for more information on sending pull requests. 
32 | 33 | Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when 34 | submitting a pull request! 35 | 36 | ## From Contributor To Committer 37 | 38 | We love contributions from our community! If you'd like to go a step beyond contributor 39 | and become a committer with full write access and a say in the project, you must 40 | be invited to the project. The existing committers employ an internal nomination 41 | process that must reach lazy consensus (silence is approval) before invitations 42 | are issued. If you feel you are qualified and want to get more deeply involved, 43 | feel free to reach out to existing committers to have a conversation about that. 44 | 45 | ## Security Issues 46 | 47 | Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html). 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Adobe 2 | All Rights Reserved. 3 | 4 | NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | accordance with the terms of the license agreement accompanying it. 6 | 7 | MIT License 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrustMark 2 | 3 | This repository contains the official, open source implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2]. 4 | 5 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK. 6 | 7 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA. 8 | 9 | ## Overview 10 | 11 | This repository contains the following directories: 12 | 13 | - `/python`: Python implementation of TrustMark for encoding, decoding and removing image watermarks (using PyTorch). 
For information on configuring TrustMark in Python, see [Configuring TrustMark](python/CONFIG.md).
14 | - `/js`: JavaScript implementation of TrustMark decoding of image watermarks (using ONNX). For more information, see [TrustMark - JavaScript implementation](js/README.md).
15 | - `/rust`: Rust implementation of TrustMark. For more information, see [TrustMark — Rust implementation](rust/README.md).
16 | - `/c2pa`: Python example of how to indicate the presence of a TrustMark watermark in a C2PA manifest. For more information, see [Using TrustMark with C2PA](c2pa/README.md).
17 | 
18 | Model files (**ckpt** PyTorch file for Python and **onnx** ONNX file for JavaScript) are not packaged in this repository due to their size, but are downloaded upon first use. See the code for [URLs and md5 hashes](https://github.com/adobe/trustmark/blob/4ef0dde4abd84d1c6873e7c5024482f849db2c73/python/trustmark/trustmark.py#L30) for a direct download link.
19 | 
20 | More information:
21 | 
22 | - For answers to common questions, see the [FAQ](FAQ.md).
23 | - For information on configuring TrustMark in Python, see [Configuring TrustMark](python/CONFIG.md).
24 | 
25 | ## Installation
26 | 
27 | ### Prerequisite
28 | 
29 | You must have Python 3.8.5 or higher to use the TrustMark Python implementation.
30 | 
31 | ### Installing from PyPI
32 | 
33 | The easiest way to install TrustMark is from the [Python Package Index (PyPI)](https://pypi.org/project/trustmark/) by entering this command:
34 | 
35 | ```
36 | pip install trustmark
37 | ```
38 | 
39 | Alternatively, after you've cloned the repository, you can install from the `python` directory:
40 | 
41 | ```
42 | cd trustmark/python
43 | pip install .
44 | ```
45 | 
46 | ## Quickstart
47 | 
48 | To get started quickly, run the `python/test.py` script that provides examples of watermarking several
49 | image files from the `images` directory.
50 | 
51 | ### Run the example
52 | 
53 | Run the example as follows:
54 | 
55 | ```sh
56 | cd trustmark/python
57 | python test.py
58 | ```
59 | 
60 | You'll see output like this:
61 | 
62 | ```
63 | Initializing TrustMark watermarking with ECC using [cpu]
64 | Extracted secret: 1000000100001110000010010001011110010001011000100000100110110 (schema 1)
65 | PSNR = 50.357909
66 | No secret after removal
67 | ```
68 | 
69 | ### Example script
70 | 
71 | The `python/test.py` script provides examples of watermarking a JPEG photo, a JPEG GenAI image, and an RGBA PNG image. The example uses TrustMark variant Q to encode the word `mysecret` in ASCII7 encoding into the image `ufo_240.jpg`, which is then decoded, and then removed from the image.
72 | 
73 | ```python
74 | from trustmark import TrustMark
75 | from PIL import Image
76 | 
77 | # init
78 | tm=TrustMark(verbose=True, model_type='Q') # or try P
79 | 
80 | # encoding example
81 | cover = Image.open('images/ufo_240.jpg').convert('RGB')
82 | tm.encode(cover, 'mysecret').save('ufo_240_Q.png')
83 | 
84 | # decoding example
85 | cover = Image.open('images/ufo_240_Q.png').convert('RGB')
86 | wm_secret, wm_present, wm_schema = tm.decode(cover)
87 | 
88 | if wm_present:
89 |     print(f'Extracted secret: {wm_secret}')
90 | else:
91 |     print('No watermark detected')
92 | 
93 | # removal example
94 | stego = Image.open('images/ufo_240_Q.png').convert('RGB')
95 | im_recover = tm.remove_watermark(stego)
96 | im_recover.save('images/recovered.png')
97 | ```
98 | 
99 | ## GPU setup
100 | 
101 | TrustMark runs well on CPU hardware.
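Before installing GPU packages, it can be worth confirming whether PyTorch already sees a CUDA device. The snippet below is a minimal sketch assuming only that `torch` and `trustmark` are installed; as shown in the Quickstart output above, TrustMark prints the device it selected when `verbose=True`:

```python
import torch
from trustmark import TrustMark

# Report whether PyTorch can see a CUDA device.
if torch.cuda.is_available():
    print(f'CUDA device available: {torch.cuda.get_device_name(0)}')
else:
    print('No CUDA device found; TrustMark will run on the CPU')

# TrustMark reports the device it selected, e.g. "using [cpu]".
tm = TrustMark(verbose=True, model_type='Q')
```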
102 | 
103 | To leverage GPU compute for the PyTorch implementation on Ubuntu Linux, first install Conda, then use the following commands to install:
104 | 
105 | ```sh
106 | conda create --name trustmark python=3.10
107 | conda activate trustmark
108 | conda install pytorch cudatoolkit=12.8 -c pytorch -c conda-forge
109 | pip install torch==2.1.2 torchvision==0.16.2 -f https://download.pytorch.org/whl/torch_stable.html
110 | pip install .
111 | ```
112 | 
113 | For the JavaScript implementation, a Chromium browser automatically uses WebGPU, if available.
114 | 
115 | ## Data schema
116 | 
117 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
118 | You can configure an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
119 | 
120 | In payload encoding, the version bits comprise two reserved (unused) bits and two bits encoding an integer value 0-3 that specifies the data schema as follows:
121 | - 0: BCH_SUPER
122 | - 1: BCH_5
123 | - 2: BCH_4
124 | - 3: BCH_3
125 | 
126 | For more details and information on configuring the encoding mode in Python, see [Configuring TrustMark](python/CONFIG.md).
127 | 
128 | ## Citation
129 | 
130 | If you find this work useful, please cite the repository and/or TrustMark paper as follows:
131 | 
132 | ```
133 | @article{trustmark,
134 |   title={Trustmark: Universal Watermarking for Arbitrary Resolution Images},
135 |   author={Bui, Tu and Agarwal, Shruti and Collomosse, John},
136 |   journal = {ArXiv e-prints},
137 |   archivePrefix = "arXiv",
138 |   eprint = {2311.18297},
139 |   year = 2023,
140 |   month = nov
141 | }
142 | ```
143 | 
144 | ## License
145 | 
146 | This package is distributed under the terms of the [MIT license](https://github.com/adobe/trustmark/blob/main/LICENSE).
147 | 

--------------------------------------------------------------------------------
/c2pa/README.md:
--------------------------------------------------------------------------------
1 | # Using TrustMark with C2PA
2 | 
3 | Open standards such as Content Credentials, developed by the [Coalition for Content Provenance and Authenticity (C2PA)](https://c2pa.org/), describe ways to encode information about an image’s history or _provenance_, such as how and when it was made. This information is usually carried within the image’s metadata.
4 | 
5 | ## Durable Content Credentials
6 | 
7 | C2PA manifest data can be accidentally removed when the image is shared through platforms that do not yet support the standard. If a copy of the manifest data is retained in a database, the TrustMark identifier carried inside the watermark can be used as a key to look up that information from the database. This is referred to as a [_Durable Content Credential_](https://contentauthenticity.org/blog/durable-content-credentials) and the technical term for the identifier is a _soft binding_.
8 | 
9 | To create a soft binding, TrustMark encodes a random identifier via one of the encoding types in the [data schema](../README.md#data-schema). For example, [`c2pa/c2pa_watermark_example.py`](https://github.com/adobe/trustmark/blob/main/c2pa/c2pa_watermark_example.py) shows how to reflect the identifier within the C2PA manifest using a _soft binding assertion_.
10 | 
11 | For more information, see the [FAQ](../FAQ.md#how-does-trustmark-align-with-provenance-standards-such-as-the-c2pa).
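The sketch below illustrates this look-up flow end to end, using the same `trustmark` calls as the example script; the in-memory `manifest_store` dictionary is a hypothetical stand-in for a real manifest database:

```python
import random
from PIL import Image
from trustmark import TrustMark

tm = TrustMark(verbose=False, model_type='Q', encoding_type=TrustMark.Encoding.BCH_4)

# Generate a random identifier sized to the schema's protected payload.
wm_id = ''.join(random.choice('01') for _ in range(tm.schemaCapacity()))

# Embed the identifier, and keep the manifest retrievable by that identifier.
cover = Image.open('example.jpg').convert('RGB')
tm.encode(cover, wm_id, MODE='binary').save('example_wm.png')
manifest_store = {wm_id: {'title': 'Watermarked Image'}}  # stand-in for a database

# Later, even if the manifest was stripped from the file, the identifier
# recovered from the watermark can be used to look the manifest back up.
recovered_id, present, _ = tm.decode(Image.open('example_wm.png').convert('RGB'), 'binary')
if present:
    print(manifest_store.get(recovered_id, 'manifest not found'))
```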
12 | 
13 | ## Signpost watermark
14 | 
15 | TrustMark [coexists well with most other image watermarks](https://arxiv.org/abs/2501.17356) and so can be used as a _signpost_ to indicate the co-presence of another watermarking technology. This can be helpful since, as an open technology, TrustMark can be used to indicate (signpost) which decoder to obtain and run on an image to decode a soft binding identifier for C2PA.
16 | 
17 | In this mode, the encoding should be `Encoding.BCH_SUPER` and the payload should contain an integer identifier that describes the co-present watermark. The integer should be taken from the registry of C2PA-approved watermarks listed in the normative C2PA [softbinding-algorithms-list](https://github.com/c2pa-org/softbinding-algorithms-list) repository.
18 | 

--------------------------------------------------------------------------------
/c2pa/c2pa_watermark_example.py:
--------------------------------------------------------------------------------
1 | import random,os
2 | from PIL import Image
3 | import json
4 | import struct
5 | 
6 | from trustmark import TrustMark
7 | 
8 | from PIL import Image
9 | 
10 | TM_SCHEMA_CODE=TrustMark.Encoding.BCH_4
11 | 
12 | def uuidgen(bitlen):
13 | 
14 |     id = ''.join(random.choice('01') for _ in range(bitlen))
15 |     return id
16 | 
17 | 
18 | def embed_watermark(img_in, img_out, watermarkID, tm):
19 |     cover = Image.open(img_in)
20 |     rgb=cover.convert('RGB')
21 |     has_alpha=cover.mode== 'RGBA'
22 |     if (has_alpha):
23 |         alpha=cover.split()[-1]
24 |     encoded=tm.encode(rgb, watermarkID, MODE='binary')
25 |     params={
26 |         "exif":cover.info.get('exif'),
27 |         "icc_profile":cover.info.get('icc_profile'),
28 |         "dpi":cover.info.get('dpi')
29 |     }
30 |     not_none_params = {k:v for k, v in params.items() if v is not None}
31 |     encoded.save(img_out, **not_none_params)
32 | 
33 | def build_manifest(watermarkID, img_in):
34 | 
35 |     assertions=[]
36 |     assertions.append(build_softbinding('com.adobe.trustmark.Q',str(TM_SCHEMA_CODE)+"*"+watermarkID))
37 |     actions=[]
38 |     act=dict()
39 |     act['action']='c2pa.watermarked'
40 |     actions.append(act)
41 | 
42 |     manifest=dict()
43 |     manifest['claim_generator']="python_trustmark/c2pa"
44 |     manifest['title']="Watermarked Image"
45 |     manifest['thumbnail']=dict()
46 |     manifest['ingredient_paths']=[img_in]
47 | 
48 |     ext=img_in.split('.')[-1]
49 |     manifest['thumbnail']['format']="image/"+ext
50 |     manifest['thumbnail']['identifier']=img_in
51 |     manifest['assertions']=assertions
52 |     manifest['actions']=actions
53 | 
54 |     return manifest
55 | 
56 | def build_softbinding(alg,val):
57 |     sba=dict()
58 |     sba['label']='c2pa.soft-binding'
59 |     sba['data']=dict()
60 |     sba['data']['alg']=alg
61 |     sba['data']['blocks']=list()
62 |     blk=dict()
63 |     blk['scope']=dict()
64 |     blk['value']=val
65 |     sba['data']['blocks'].append(blk)
66 |     return sba
67 | 
68 | def manifest_add_signing(mf):
69 |     mf['alg']='es256'
70 |     mf['ta_url']='http://timestamp.digicert.com'
71 |     mf['private_key']='keys/es256_private.key'
72 |     mf['sign_cert']='keys/es256_certs.pem'
73 |     return mf
74 | 
75 | def manifest_add_creator(mf,name):
76 |     cwa=dict()
77 |     cwa['label']='stds.schema-org.CreativeWork'
78 |     cwa['data']=dict()
79 |     cwa['data']['@context']='https://schema.org'
80 |     cwa['data']['@type']='CreativeWork'
81 |     author=dict()
82 |     author['@type']='Person'
83 |     author['name']=name
84 |     cwa['data']['author']=[author]
85 |     mf['assertions'].append(cwa)
86 |     return mf
87 | 
88 | 
89 | def main() :
90 | 
91 |     img_in='example.jpg'
92 |     img_out='example_wm.jpg'
93 |     img_out_signed='example_wm_signed.jpg'
94 | 
95 |     # Generate a random watermark ID
96 |     tm=TrustMark(verbose=True, model_type='Q', encoding_type=TM_SCHEMA_CODE)
97 |     bitlen=tm.schemaCapacity()
98 |     id=uuidgen(bitlen)
99 | 
100 |     # Encode watermark
101 |     embed_watermark(img_in, img_out, id, tm)
102 | 
103 |     # Build manifest
104 |     mf=build_manifest(id, img_in)
105 |     mf=manifest_add_creator(mf,"Walter Mark")
106 |     mf=manifest_add_signing(mf)
107 | 
108 | 
109 |     fp=open('manifest.json','wt')
110 |     fp.write(json.dumps(mf, indent=4))
111 |     fp.close()
112 |     os.system('c2patool '+img_out+' -m manifest.json -f -o '+img_out_signed)
113 | 
114 | 
115 | 
116 |     # Check watermark present
117 |     stego = Image.open(img_out_signed).convert('RGB')
118 |     wm_id, wm_present, wm_schema = tm.decode(stego, 'binary')
119 |     if wm_present:
120 |         print('Watermark detected in signed image')
121 |         if wm_id==id:
122 |             print('Watermark is correct')
123 |             print(id)
124 |         else:
125 |             print('Watermark does not match!')
126 |             print(id)
127 |             print(wm_id)
128 |     else:
129 |         print('No watermark detected!')
130 | 
131 | 
132 | 
133 | if __name__ == "__main__":
134 |     main()
135 | 
136 | 
137 | 

--------------------------------------------------------------------------------
/c2pa/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example.jpg

--------------------------------------------------------------------------------
/c2pa/example_wm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example_wm.jpg

--------------------------------------------------------------------------------
/c2pa/example_wm_signed.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example_wm_signed.jpg

--------------------------------------------------------------------------------
/c2pa/keys/README.md:
--------------------------------------------------------------------------------
1 | # CAI Test Signing Keys
2 | 
3 | These public test keys are updated from time to time; they were taken in December 2023 from https://github.com/contentauth/c2patool/tree/main/sample
4 | 
5 | Please source new keys if expired.
6 | 
7 | These are not secrets.
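If the committed keys have expired and you only need a throwaway ES256 pair for local experiments, a sketch along these lines (using the third-party `cryptography` package; the output file names match this directory) can generate a self-signed one. This is a hedged convenience only: c2patool may require a particular certificate profile, so prefer sourcing keys from the c2patool samples linked above.

```python
import datetime
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import NameOID

# ES256 means ECDSA over the P-256 curve with SHA-256.
key = ec.generate_private_key(ec.SECP256R1())
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, 'Local Test Signer')])
now = datetime.datetime.utcnow()
cert = (
    x509.CertificateBuilder()
    .subject_name(name)
    .issuer_name(name)  # self-signed: issuer == subject
    .public_key(key.public_key())
    .serial_number(x509.random_serial_number())
    .not_valid_before(now)
    .not_valid_after(now + datetime.timedelta(days=365))
    .sign(key, hashes.SHA256())
)

with open('es256_private.key', 'wb') as f:
    f.write(key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ))
with open('es256_certs.pem', 'wb') as f:
    f.write(cert.public_bytes(serialization.Encoding.PEM))
```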
8 | -------------------------------------------------------------------------------- /c2pa/keys/es256_certs.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIChzCCAi6gAwIBAgIUcCTmJHYF8dZfG0d1UdT6/LXtkeYwCgYIKoZIzj0EAwIw 3 | gYwxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJU29tZXdoZXJl 4 | MScwJQYDVQQKDB5DMlBBIFRlc3QgSW50ZXJtZWRpYXRlIFJvb3QgQ0ExGTAXBgNV 5 | BAsMEEZPUiBURVNUSU5HX09OTFkxGDAWBgNVBAMMD0ludGVybWVkaWF0ZSBDQTAe 6 | Fw0yMjA2MTAxODQ2NDBaFw0zMDA4MjYxODQ2NDBaMIGAMQswCQYDVQQGEwJVUzEL 7 | MAkGA1UECAwCQ0ExEjAQBgNVBAcMCVNvbWV3aGVyZTEfMB0GA1UECgwWQzJQQSBU 8 | ZXN0IFNpZ25pbmcgQ2VydDEZMBcGA1UECwwQRk9SIFRFU1RJTkdfT05MWTEUMBIG 9 | A1UEAwwLQzJQQSBTaWduZXIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQPaL6R 10 | kAkYkKU4+IryBSYxJM3h77sFiMrbvbI8fG7w2Bbl9otNG/cch3DAw5rGAPV7NWky 11 | l3QGuV/wt0MrAPDoo3gwdjAMBgNVHRMBAf8EAjAAMBYGA1UdJQEB/wQMMAoGCCsG 12 | AQUFBwMEMA4GA1UdDwEB/wQEAwIGwDAdBgNVHQ4EFgQUFznP0y83joiNOCedQkxT 13 | tAMyNcowHwYDVR0jBBgwFoAUDnyNcma/osnlAJTvtW6A4rYOL2swCgYIKoZIzj0E 14 | AwIDRwAwRAIgOY/2szXjslg/MyJFZ2y7OH8giPYTsvS7UPRP9GI9NgICIDQPMKrE 15 | LQUJEtipZ0TqvI/4mieoyRCeIiQtyuS0LACz 16 | -----END CERTIFICATE----- 17 | -----BEGIN CERTIFICATE----- 18 | MIICajCCAg+gAwIBAgIUfXDXHH+6GtA2QEBX2IvJ2YnGMnUwCgYIKoZIzj0EAwIw 19 | dzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQHDAlTb21ld2hlcmUx 20 | GjAYBgNVBAoMEUMyUEEgVGVzdCBSb290IENBMRkwFwYDVQQLDBBGT1IgVEVTVElO 21 | R19PTkxZMRAwDgYDVQQDDAdSb290IENBMB4XDTIyMDYxMDE4NDY0MFoXDTMwMDgy 22 | NzE4NDY0MFowgYwxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJ 23 | U29tZXdoZXJlMScwJQYDVQQKDB5DMlBBIFRlc3QgSW50ZXJtZWRpYXRlIFJvb3Qg 24 | Q0ExGTAXBgNVBAsMEEZPUiBURVNUSU5HX09OTFkxGDAWBgNVBAMMD0ludGVybWVk 25 | aWF0ZSBDQTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHllI4O7a0EkpTYAWfPM 26 | D6Rnfk9iqhEmCQKMOR6J47Rvh2GGjUw4CS+aLT89ySukPTnzGsMQ4jK9d3V4Aq4Q 27 | LsOjYzBhMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQDAgGGMB0GA1UdDgQW 28 | BBQOfI1yZr+iyeUAlO+1boDitg4vazAfBgNVHSMEGDAWgBRembiG4Xgb2VcVWnUA 29 | UrYpDsuojDAKBggqhkjOPQQDAgNJADBGAiEAtdZ3+05CzFo90fWeZ4woeJcNQC4B 30 | 84Ill3YeZVvR8ZECIQDVRdha1xEDKuNTAManY0zthSosfXcvLnZui1A/y/DYeg== 31 | -----END CERTIFICATE----- 32 | 33 | -------------------------------------------------------------------------------- /c2pa/keys/es256_private.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgfNJBsaRLSeHizv0m 3 | GL+gcn78QmtfLSm+n+qG9veC2W2hRANCAAQPaL6RkAkYkKU4+IryBSYxJM3h77sF 4 | iMrbvbI8fG7w2Bbl9otNG/cch3DAw5rGAPV7NWkyl3QGuV/wt0MrAPDo 5 | -----END PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /c2pa/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "claim_generator": "python_trustmark/c2pa", 3 | "title": "Watermarked Image", 4 | "thumbnail": { 5 | "format": "image/jpg", 6 | "identifier": "example.jpg" 7 | }, 8 | "ingredient_paths": [ 9 | "example.jpg" 10 | ], 11 | "assertions": [ 12 | { 13 | "label": "c2pa.soft-binding", 14 | "data": { 15 | "alg": "com.adobe.trustmark.Q", 16 | "blocks": [ 17 | { 18 | "scope": {}, 19 | "value": "2*00000010010100000100001111011011010011100010011101000010100000001110" 20 | } 21 | ] 22 | } 23 | }, 24 | { 25 | "label": "stds.schema-org.CreativeWork", 26 | "data": { 27 | "@context": "https://schema.org", 28 | "@type": "CreativeWork", 29 | "author": [ 30 | { 31 | "@type": "Person", 32 | "name": "Walter Mark" 33 | } 34 | ] 35 | 
}
36 |         }
37 |     ],
38 |     "actions": [
39 |         {
40 |             "action": "c2pa.watermarked"
41 |         }
42 |     ],
43 |     "alg": "es256",
44 |     "ta_url": "http://timestamp.digicert.com",
45 |     "private_key": "keys/es256_private.key",
46 |     "sign_cert": "keys/es256_certs.pem"
47 | }

--------------------------------------------------------------------------------
/images/ghost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost.png

--------------------------------------------------------------------------------
/images/ghost_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost_P.png

--------------------------------------------------------------------------------
/images/ghost_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost_Q.png

--------------------------------------------------------------------------------
/images/ripley.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley.jpg

--------------------------------------------------------------------------------
/images/ripley_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley_P.png

--------------------------------------------------------------------------------
/images/ripley_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley_Q.png

--------------------------------------------------------------------------------
/images/ufo_240.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240.jpg

--------------------------------------------------------------------------------
/images/ufo_240_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240_P.png

--------------------------------------------------------------------------------
/images/ufo_240_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240_Q.png

--------------------------------------------------------------------------------
/js/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark JavaScript example
2 | 
3 | The [`js`](https://github.com/adobe/trustmark/tree/main/js) directory contains an example JavaScript implementation of decoding TrustMark watermarks embedded in images. It provides a minimal example of a client-side JavaScript application, which could be applied, for example, to a browser extension.
4 | 
5 | ## Overview
6 | 
7 | The example consists of key modules that handle image preprocessing, watermark detection, and data decoding. It supports the `Q` and `P` variants of the TrustMark watermarking schema.
8 | 
9 | NOTE: The TrustMark JavaScript implementation only decodes watermarked images, in contrast to the Python implementation, which can both encode and decode.
10 | 
11 | ## Components
12 | 
13 | This example consists of a simple HTML file, `index.html`, that loads two JavaScript files:
14 | 
15 | - [`tm_watermark.js`](https://github.com/adobe/trustmark/blob/main/js/tm_watermark.js), the core module that handles watermark detection and decoding and defines key functions for processing images and extracting watermark data. The `modelConfigs` array specifies parameters such as the TrustMark model variant. As provided, the code checks for both Q and P variants.
16 | - [`tm_datalayer.js`](https://github.com/adobe/trustmark/blob/main/js/tm_datalayer.js) handles data decoding and schema-specific processing. It also implements error correction and interpretation of binary watermark data.
17 | 
18 | If GPU compute is available (if you're using Google Chrome, check `chrome://gpu`), the code automatically uses WebGPU to process the ONNX models. WebGPU only runs in a secure context, which means on localhost or over an HTTPS link. You can start a local HTTPS server by running the `server.py` script with a suitable OpenSSL certificate in `server.pem`.
19 | 
20 | ## Key parameters
21 | 
22 | The TrustMark watermark variants to check when decoding are listed in the `modelConfigs` array at the top of `tm_watermark.js`. As provided, it lists the Q and P variants:
23 | 
24 | ```js
25 | const modelConfigs = [
26 |   { variantcode: 'Q', fname: 'decoder_Q', sessionVar: 'session_wmarkQ', resolution: 256, squarecrop: false },
27 |   { variantcode: 'P', fname: 'decoder_P', sessionVar: 'session_wmarkP', resolution: 224, squarecrop: true },
28 | ];
29 | ```
30 | 
31 | ## Run the example
32 | 
33 | To run the example:
34 | 
35 | 1. Start a local web server; for example, using Python:
36 |    ```
37 |    python -m http.server 8000
38 |    ```
39 | 1. Open `index.html` in a browser to run the example.
40 | 1. Drag and drop images (for example from the provided [`images`](https://github.com/adobe/trustmark/tree/main/images) directory) onto the indicated area in the web page, which will display information if the image contains a TrustMark watermark (see example below).
41 | 
42 | To use the code in your own project, simply include the two JavaScript files as usual.
43 | 
44 | ## Example output
45 | 
46 | For an image containing a TrustMark watermark:
47 | 
48 | ```
49 | [4:12:48 PM] Decoding watermark...
50 | [4:12:48 PM] Watermark Found (BCH_5):
51 | 1101101111100111100111100101110001111100101100101111010000000
52 | C2PA Assertion:
53 | {
54 |   "c2pa.soft-binding": {
55 |     "alg": "com.adobe.trustmark.Q",
56 |     "blocks": [
57 |       {
58 |         "scope": {},
59 |         "value": "1*1101101111100111100111100101110001111100101100101111010000000"
60 |       }
61 |     ]
62 |   }
63 | }
64 | ```
65 | 
66 | 

--------------------------------------------------------------------------------
/js/deps/ort-wasm-simd-threaded.jsep.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/js/deps/ort-wasm-simd-threaded.jsep.wasm

--------------------------------------------------------------------------------
/js/deps/ort-wasm-simd-threaded.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/js/deps/ort-wasm-simd-threaded.wasm

--------------------------------------------------------------------------------
/js/index.html:
--------------------------------------------------------------------------------
[HTML markup not captured in this dump; the page's visible text is "TrustMark JS / Client-side Decode Demonstrator", "Drag & drop your image here", and "— or —".]

--------------------------------------------------------------------------------
/js/tm_datalayer.js:
--------------------------------------------------------------------------------
1 | /*!
2 |  * TrustMark Data Layer module
3 |  * Copyright 2024 Adobe. All rights reserved.
4 |  * Licensed under the MIT License.
5 |  *
6 |  * NOTICE: Adobe permits you to use, modify, and distribute this file in
7 |  * accordance with the terms of the Adobe license agreement accompanying it.
8 |  */
9 | 
10 | // Initialize ECC engines for all supported versions
11 | let eccengine = [];
12 | for (let version = 0; version < 4; version++) {
13 |   eccengine.push(DataLayer_GetECCEngine(version));
14 | }
15 | 
16 | /**
17 |  * Decodes the watermark data using the given ECC engine and schema.
18 |  * Attempts fallback decoding with alternate schemas if the primary attempt fails.
19 |  *
20 |  * @param {Array} watermarkbool - The boolean array representing the watermark.
21 |  * @param {Array} eccengine - Array of ECC engines for decoding.
22 |  * @returns {Object} Decoded watermark data with schema and soft binding info.
23 |  */
24 | function DataLayer_Decode(watermarkbool, eccengine, variant) {
25 |   let version = DataLayer_GetVersion(watermarkbool);
26 |   let databits = DataLayer_GetSchemaDataBits(version);
27 | 
28 |   let data = watermarkbool.slice(0, databits);
29 |   let ecc = watermarkbool.slice(databits, 96);
30 | 
31 |   let dataobj = BCH_Decode(eccengine[version], data, ecc);
32 |   dataobj.schema = DataLayer_GetSchemaName(version);
33 | 
34 |   if (!dataobj.valid) {
35 |     // Attempt decoding with alternate schemas
36 |     for (let alt = 0; alt < 3; alt++) {
37 |       if (alt === version) continue;
38 | 
39 |       databits = DataLayer_GetSchemaDataBits(alt);
40 |       data = watermarkbool.slice(0, databits);
41 |       ecc = watermarkbool.slice(databits, 96);
42 | 
43 |       dataobj = BCH_Decode(eccengine[alt], data, ecc);
44 |       dataobj.schema = DataLayer_GetSchemaName(alt);
45 |       if (dataobj.valid) break;
46 |     }
47 |   }
48 | 
49 |   // Add soft binding information
50 |   dataobj.softBindingInfo = formatSoftBindingData(dataobj.data_binary, version, variant);
51 |   return dataobj;
52 | }
53 | 
54 | /**
55 |  * Interprets the watermark data in the context of C2PA.
56 |  *
57 |  * @param {Object} dataobj - The decoded watermark data object.
58 |  * @param {number} version - The version of the schema.
59 |  * @returns {Promise} Promise resolving with the updated data object.
60 |  */
61 | function interpret_C2PA(dataobj, version) {
62 |   return new Promise((resolve, reject) => {
63 |     if (true) { // Placeholder for schema-specific logic
64 |       fetchSoftBindingInfo(dataobj.data_binary)
65 |         .then(softBindingInfo => {
66 |           if (softBindingInfo) {
67 |             dataobj.softBindingInfo = softBindingInfo;
68 |           } else {
69 |             console.warn("No soft binding info found.");
70 |           }
71 |           resolve(dataobj);
72 |         })
73 |         .catch(error => {
74 |           console.error("Error fetching soft binding info:", error);
75 |           reject(error);
76 |         });
77 |     } else {
78 |       resolve(dataobj);
79 |     }
80 |   });
81 | }
82 | 
83 | /**
84 |  * Extracts the schema version from the last two bits of the watermark boolean array.
85 |  *
86 |  * @param {Array} watermarkbool - The boolean array representing the watermark.
87 |  * @returns {number} The schema version as an integer.
88 | */ 89 | function DataLayer_GetVersion(watermarkbool) { 90 | watermarkbool = watermarkbool.slice(-2); 91 | return watermarkbool[0] * 2 + watermarkbool[1]; 92 | } 93 | 94 | /** 95 | * Retrieves the ECC engine for the given schema version. 96 | * 97 | * @param {number} version - The schema version. 98 | * @returns {Object} The corresponding ECC engine. 99 | */ 100 | function DataLayer_GetECCEngine(version) { 101 | switch (version) { 102 | case 0: 103 | return BCH(8, 137); 104 | case 1: 105 | return BCH(5, 137); 106 | case 2: 107 | return BCH(4, 137); 108 | case 3: 109 | return BCH(3, 137); 110 | default: 111 | return -1; 112 | } 113 | } 114 | 115 | /** 116 | * Retrieves the number of data bits for the given schema version. 117 | * 118 | * @param {number} version - The schema version. 119 | * @returns {number} The number of data bits. 120 | */ 121 | function DataLayer_GetSchemaDataBits(version) { 122 | switch (version) { 123 | case 0: 124 | return 40; 125 | case 1: 126 | return 61; 127 | case 2: 128 | return 68; 129 | case 3: 130 | return 75; 131 | default: 132 | console.error("Invalid schema version"); 133 | return 0; 134 | } 135 | } 136 | 137 | /** 138 | * Retrieves the name of the schema for the given version. 139 | * 140 | * @param {number} version - The schema version. 141 | * @returns {string} The schema name. 142 | */ 143 | function DataLayer_GetSchemaName(version) { 144 | switch (version) { 145 | case 0: 146 | return "BCH_SUPER"; 147 | case 1: 148 | return "BCH_5"; 149 | case 2: 150 | return "BCH_4"; 151 | case 3: 152 | return "BCH_3"; 153 | default: 154 | return "Invalid"; 155 | } 156 | } 157 | 158 | /** 159 | * Formats the encoded watermark data into a structured JSON object. 160 | * 161 | * @param {Array} encodedData - The binary data representing the watermark. 162 | * @param {number} version - The schema version. 163 | * @returns {Object|null} Formatted JSON object or null in case of errors. 164 | */ 165 | function formatSoftBindingData(encodedData, version, variant) { 166 | try { 167 | const binaryString = Array.isArray(encodedData) 168 | ? encodedData.join('') 169 | : String(encodedData); 170 | 171 | return { 172 | "c2pa.soft-binding": { 173 | "alg": `com.adobe.trustmark.${variant}`, 174 | "blocks": [ 175 | { 176 | "scope": {}, 177 | "value": `${version}*${binaryString}` 178 | } 179 | ] 180 | } 181 | }; 182 | } catch (error) { 183 | console.error("Error formatting soft binding data:", error); 184 | return null; 185 | } 186 | } 187 | 188 | -------------------------------------------------------------------------------- /js/tm_watermark.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * TrustMark JS Watermarking Decoder Module 3 | * Copyright 2024 Adobe. All rights reserved. 4 | * Licensed under the MIT License. 5 | * 6 | * NOTICE: Adobe permits you to use, modify, and distribute this file in 7 | * accordance with the terms of the Adobe license agreement accompanying it. 
8 | */ 9 | 10 | // Source for ONNX model binaries 11 | const MODEL_BASE_URL = "https://cc-assets.netlify.app/watermarking/trustmark-models/"; 12 | 13 | // List all watermark models for decoding 14 | const modelConfigs = [ 15 | { variantcode: 'Q', fname: 'decoder_Q', sessionVar: 'session_wmarkQ', resolution: 256, squarecrop: false }, 16 | { variantcode: 'P', fname: 'decoder_P', sessionVar: 'session_wmarkP', resolution: 224, squarecrop: true }, 17 | ]; 18 | 19 | const sessions = {}; 20 | let session_resize; 21 | 22 | // Load model immediately 23 | (async () => { 24 | for (const config of modelConfigs) { 25 | let startTime = new Date(); 26 | try { 27 | sessions[config.sessionVar] = await ort.InferenceSession.create(`${MODEL_BASE_URL}${config.fname}.onnx`, { executionProviders: ['webgpu'] }); 28 | let timeElapsed = new Date() - startTime; 29 | console.log(`${config.fname} model loaded in ${timeElapsed / 1000} seconds`); 30 | } catch (error) { 31 | console.error(`Could not load ${config.fname} watermark decoder model`, error); 32 | } 33 | } 34 | let startTime = new Date(); 35 | try { 36 | session_resize = await ort.InferenceSession.create(`${MODEL_BASE_URL}resizer.onnx`, { executionProviders: ['wasm'] }); // cannot use GPU for this due to lack of antialias 37 | let timeElapsed = new Date() - startTime; 38 | console.log(`Image downscaler model loaded in ${timeElapsed / 1000} seconds`); 39 | } 40 | catch (error) { 41 | console.log('Could not load image downscaler model', error); 42 | console.log(error) 43 | } 44 | })(); 45 | 46 | 47 | /* WebGPU will fail silently and intermittently if multiple concurrent inference calls are made 48 | this routine ensures sequential calling from multiple threads. Note this simple JS demo is single threaded. */ 49 | let inferenceLock = false; 50 | 51 | async function safeRunInference(session, feed) { 52 | while (inferenceLock) { 53 | // Wait for any ongoing inference 54 | await new Promise(resolve => setTimeout(resolve, 30)); 55 | } 56 | 57 | inferenceLock = true; // Lock inference 58 | try { 59 | return await session.run(feed); // Run the inference 60 | } catch (error) { 61 | console.error("Inference error:", error); // Log any error 62 | throw error; // Rethrow for further debugging 63 | } finally { 64 | inferenceLock = false; // Unlock after inference 65 | } 66 | } 67 | 68 | 69 | /** 70 | * Converts an image URL to a tensor suitable for processing. 71 | * @param {string} imageUrl - The URL of the image to load. 72 | * @returns {Promise} The processed tensor. 
73 | */ 74 | async function loadImageAsTensor(imageUrl) { 75 | const img = new Image(); 76 | img.src = imageUrl; 77 | 78 | return new Promise((resolve, reject) => { 79 | img.onload = () => { 80 | const canvas = document.createElement('canvas'); 81 | const ctx = canvas.getContext('2d'); 82 | canvas.width = img.width; 83 | canvas.height = img.height; 84 | ctx.drawImage(img, 0, 0); 85 | const imgData = ctx.getImageData(0, 0, img.width, img.height); 86 | 87 | const { data, width, height } = imgData; 88 | const totalPixels = width * height; 89 | const imageTensor = new Float32Array(totalPixels * 3); 90 | 91 | let j = 0; 92 | const page = width * height; 93 | const twopage = 2 * page; 94 | 95 | for (let i = 0; i < totalPixels; i++) { 96 | const index = i * 4; 97 | imageTensor[j] = data[index] / 255.0; // Red channel 98 | imageTensor[j + page] = data[index + 1] / 255.0; // Green channel 99 | imageTensor[j + twopage] = data[index + 2] / 255.0; // Blue channel 100 | j++; 101 | } 102 | 103 | resolve(new ort.Tensor('float32', imageTensor, [1, 3, height, width])); 104 | }; 105 | 106 | img.onerror = () => reject("Failed to load image"); 107 | }); 108 | } 109 | 110 | /** 111 | * Computes scale factors for image resizing with precision. 112 | * @param {Array} targetDims - Target dimensions for the image. 113 | * @param {Array} inputDims - Input tensor dimensions. 114 | * @returns {Float32Array} Scale factors as a tensor. 115 | */ 116 | function computeScalesFixed(targetDims, inputDims) { 117 | const [batch, channels, height, width] = inputDims; 118 | const [targetHeight, targetWidth] = targetDims; 119 | 120 | function computeScale(originalSize, targetSize) { 121 | let minScale = targetSize / originalSize; 122 | let maxScale = (targetSize + 1) / originalSize; 123 | let scale; 124 | let adjustedSize; 125 | 126 | const tolerance = 1e-12; 127 | let iterations = 0; 128 | const maxIterations = 100; 129 | 130 | while (iterations < maxIterations) { 131 | scale = (minScale + maxScale) / 2; 132 | adjustedSize = Math.floor(originalSize * scale + tolerance); 133 | 134 | if (adjustedSize < targetSize) { 135 | minScale = scale; 136 | } else if (adjustedSize > targetSize) { 137 | maxScale = scale; 138 | } else { 139 | break; // Found the correct scale 140 | } 141 | 142 | iterations++; 143 | } 144 | 145 | return scale; 146 | } 147 | 148 | const scaleH = computeScale(height, targetHeight); 149 | const scaleW = computeScale(width, targetWidth); 150 | 151 | return new Float32Array([1.0, 1.0, scaleH, scaleW]); 152 | } 153 | 154 | /** 155 | * Resizes the image tensor to a square size suitable for watermark decoding. 156 | * @param {ort.Tensor} inputTensor - The input image tensor. 157 | * @param {number} targetSize - The target size for resizing. 158 | * @returns {Promise} The resized tensor. 
159 | */ 160 | async function runResizeModelSquare(inputTensor, targetSize, force_square) { 161 | try { 162 | const inputDims = inputTensor.dims; // Get dimensions of the input tensor 163 | const [batch, channels, height, width] = inputDims; 164 | 165 | // Compute the aspect ratio 166 | const aspectRatio = width / height; 167 | const lscape= (aspectRatio>=1.0); 168 | 169 | let croppedTensor = inputTensor; 170 | let cropWidth = width; 171 | let cropHeight = height; 172 | 173 | // If the aspect ratio is greater than 2.0, we need to crop the center square 174 | if (lscape && (aspectRatio > 2.0 || force_square)) { 175 | cropWidth = height; // Take a square from the width 176 | const offsetX = Math.floor((width - cropWidth) / 2); // Horizontal center crop 177 | croppedTensor = await cropTensor(inputTensor, offsetX, 0, cropWidth, height); 178 | } 179 | 180 | if (!lscape && (aspectRatio < 0.5 || force_square)) { 181 | cropHeight = width; // Take a square from the height 182 | const offsetY = Math.floor((height - cropHeight) / 2); // Vertical center crop 183 | croppedTensor = await cropTensor(inputTensor, 0, offsetY, width, cropHeight); 184 | } 185 | 186 | // After cropping, resize the tensor to the target size 187 | const targetDims = [targetSize, targetSize]; 188 | const scales = computeScalesFixed(targetDims, [batch, channels, cropHeight, cropWidth]); 189 | const scalesTensor = new ort.Tensor('float32', scales, [4]); 190 | 191 | // Prepare the target size tensor 192 | const targetSizeTensor = new ort.Tensor('int64', new BigInt64Array([BigInt(targetSize)]), [1]); 193 | 194 | // Set up the feeds for the model 195 | const feeds = { 196 | 'X': croppedTensor, // Cropped image tensor 197 | 'scales': scalesTensor, // Scales tensor 198 | 'target_size': targetSizeTensor // Dynamic target size tensor 199 | }; 200 | 201 | const results = await session_resize.run(feeds); 202 | return results['Y']; 203 | 204 | } catch (error) { 205 | console.error('Error during resizing:', error); 206 | return null; 207 | } 208 | } 209 | 210 | // Helper function to crop the tensor 211 | async function cropTensor(inputTensor, offsetX, offsetY, cropWidth, cropHeight) { 212 | const [batch, channels, height, width] = inputTensor.dims; 213 | const croppedData = new Float32Array(batch * channels * cropWidth * cropHeight); 214 | const inputData = inputTensor.data; 215 | 216 | let k = 0; 217 | for (let c = 0; c < channels; c++) { 218 | for (let y = 0; y < cropHeight; y++) { 219 | for (let x = 0; x < cropWidth; x++) { 220 | const srcIndex = c * width * height + (y + offsetY) * width + (x + offsetX); 221 | croppedData[k++] = inputData[srcIndex]; 222 | } 223 | } 224 | } 225 | 226 | return new ort.Tensor('float32', croppedData, [batch, channels, cropHeight, cropWidth]); 227 | } 228 | 229 | 230 | /** 231 | * Decodes the watermark from the processed image tensor. 232 | * @param {string} base64Image - Base64 representation of the image. 233 | * @returns {Promise} Decoded watermark data. 
234 | */ 235 | async function runwmark(base64Image) { 236 | 237 | let watermarks=[] 238 | let watermarks_present=[]; 239 | try { 240 | 241 | const inputTensor = await loadImageAsTensor(base64Image); 242 | 243 | for (const config of modelConfigs) { 244 | 245 | const session = sessions[config.sessionVar]; 246 | if (!session) { 247 | console.error(`Session for ${config.fname} not loaded, skipping.`); 248 | continue; 249 | } 250 | 251 | 252 | const resizedTensorWM = await runResizeModelSquare(inputTensor, config.resolution, config.squarecrop); 253 | if (!resizedTensorWM) throw new Error("Failed to resize tensor for watermark detection."); 254 | 255 | const feeds = { image: resizedTensorWM }; 256 | 257 | let startTime = new Date(); 258 | const results = await safeRunInference(session, feeds); 259 | 260 | const watermarkFloat = results['output']['cpuData']; 261 | const watermarkBool = watermarkFloat.map((v) => v >= 0); 262 | 263 | const dataObj = DataLayer_Decode(watermarkBool, eccengine, config.variantcode); 264 | console.log(`Watermark model inference in ${(new Date() - startTime)} milliseconds`); 265 | 266 | // Append results to arrays 267 | watermarks.push(dataObj); 268 | watermarks_present.push(dataObj.valid); 269 | } 270 | 271 | } catch (error) { 272 | console.error("Error in watermark decoding:", error); 273 | return { watermark_present: false, watermark: null, schema: null }; 274 | } 275 | 276 | // Get first detected watermark (if many were) 277 | const firstValidIndex = watermarks_present.findIndex(isValid => isValid === true); 278 | let watermark; 279 | let watermark_present=false; 280 | if (firstValidIndex !== -1) { 281 | watermark = watermarks[firstValidIndex]; 282 | 283 | return { 284 | watermark_present: watermark.valid, 285 | watermark: watermark.valid ? watermark.data_binary : null, 286 | schema: watermark.schema, 287 | c2padata: watermark.softBindingInfo, 288 | }; 289 | } 290 | else { 291 | return { 292 | watermark_present: false, 293 | } 294 | } 295 | 296 | 297 | } 298 | 299 | -------------------------------------------------------------------------------- /python/CONFIG.md: -------------------------------------------------------------------------------- 1 | # Configuring TrustMark 2 | 3 | ## Overview 4 | 5 | All watermarking algorithms trade off between three properties: 6 | 7 | - **Capacity (bits)** 8 | - **Robustness (to various transformations)** 9 | - **Visibility (of watermark)** 10 | 11 | This document explains how to configure TrustMark to tune these properties, however the default configuration for TrustMark (variant Q, 100% strength, BCH_5 error correction) is sufficient for most use cases. 12 | 13 | ## Model variant 14 | 15 | TrustMark has four model variants (**B**, **C**, **P**, and **Q**) that may be selected when instantiating TrustMark. All encode/decode calls on the object will use this variant. 16 | 17 | In general, we recommend using **P** or **Q**: 18 | - **P** is useful for creative applications where very high visual quality is required. 19 | - **Q** is a good all-rounder and is the default. 20 | 21 | > **Note:** Images encoded with one model variant cannot be decoded with another. 22 | 23 | | Variant | Typical PSNR | Model Size Enc/Dec (MB) | Description | 24 | |---------|--------------|-------------------------|-----------------------------------------------------------------------------------------------------| 25 | | **Q** | 43-45 | 17/45 | Default (**Q**uality). Good trade-off between robustness and imperceptibility. Uses ResNet-50 decoder. 
|
26 | | **B** | 43-45 | 17/45 | (**B**eta). Very similar to Q, included mainly for reproducing the paper. Uses ResNet-50 decoder. |
27 | | **C** | 38-39 | 17/21 | (**C**ompact). Uses a ResNet-18 decoder (smaller model size). Slightly lower visual quality. |
28 | | **P** | 48-50 | 16/45 | (**P**erceptual). Very high visual quality and good robustness. ResNet-50 decoder trained with much higher weight on perceptual loss (see paper). |
29 | 
30 | ## Watermark strength
31 | 
32 | Set the optional `WM_STRENGTH` parameter when encoding (at runtime).
33 | Its default value is `1.0`, and changing it provides a trade-off between **robustness** and **visibility**:
34 | 
35 | - Raising its value (for example, to 1.5) improves robustness (so, for example, the watermark survives printing) but increases the likelihood of ripple artifacts.
36 | - Lowering its value (for example, to 0.8) reduces any likelihood of artifacts but compromises on robustness; however, the watermark still survives milder transformations such as low-level noise, screenshots, or social media processing.
37 | 
38 | For example:
39 | 
40 | ```python
41 | encoded_image = tm.encode(input_image, payload="example", WM_STRENGTH=1.5)
42 | ```
43 | 
44 | ## Error correction level
45 | 
46 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
47 | The data schema implemented in `python/datalayer.py` enables you to choose an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
48 | 
49 | ### Encoding modes
50 | 
51 | Set the error correction level using one of the four encoding modes.
52 | 
53 | The following table describes TrustMark's encoding modes:
54 | 
55 | | Encoding | Protected payload | Number of bit flips allowed |
56 | |----------|-------------------|-----------------------------|
57 | | `Encoding.BCH_5` | 61 bits (+ 35 ECC bits) | 5 |
58 | | `Encoding.BCH_4` | 68 bits (+ 28 ECC bits) | 4 |
59 | | `Encoding.BCH_3` | 75 bits (+ 21 ECC bits) | 3 |
60 | | `Encoding.BCH_SUPER` | 40 bits (+ 56 ECC bits) | 8 |
61 | 
62 | Specify the mode when you instantiate the encoder, as follows:
63 | 
64 | ```py
65 | tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.<type>)
66 | ```
67 | 
68 | Where `<type>` is `BCH_5`, `BCH_4`, `BCH_3`, or `BCH_SUPER`.
69 | 
70 | For example:
71 | 
72 | ```py
73 | tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5)
74 | ```
75 | 
76 | The decoder automatically detects the data schema in a watermark, so you can choose the level of robustness that best suits your use case.
77 | 
78 | Selecting the model and strength implicitly selects the level of robustness and visibility of the watermark. If you have reduced robustness for lower visibility, you can regain some robustness by increasing error correction (at the cost of payload capacity). Note that even 40 bits gives a key space of around one trillion.
79 | 
80 | ## Center cropping
81 | 
82 | TrustMark generates residuals at 256 x 256 and then scales/blends them into the original image. Several derivative papers have adopted this universal resolution-scaling trick.
83 | 
84 | - If the original image is extremely long/thin (aspect ratio beyond 2:1), the residual watermark will degrade when scaled.
85 | - TrustMark addresses this by automatically center-cropping the image to a square if the aspect ratio exceeds 2.0. For example, for a 1000 x 200 image, only a 200 x 200 region in the center carries the watermark.
86 | - The aspect ratio limit can be overridden via the ASPECT_RATIO_LIM parameter. Setting it to 1.0 always forces center-crop behavior (useful for content platforms that square-crop images). This is the default when using model variant **P**.
87 | 
88 | ## Do not concentrate watermarks
89 | 
90 | Visual quality is often measured via **PSNR** (Peak Signal-to-Noise Ratio), but PSNR does not perfectly correlate with human perception of watermark visibility.
91 | 
92 | Some derivative works have tried to improve PSNR by "zero padding" or concentrating the watermark into only the central region, effectively altering fewer pixels and artificially raising the PSNR score. However, this approach **can increase human-visible artifacts** in the concentrated area.
93 | 
94 | **Parameter:** CONCENTRATE_WM_REGION
95 | - Default is 100% (no zero padding).
96 | - If you set, for example, 80% zero padding, you might inflate PSNR by ~5 dB.
97 | - At 50% zero padding, PSNR might inflate by ~10 dB.
98 | - In some extreme cases, PSNR can reach 55–60 dB but at the cost of noticeable artifacts in that smaller region.
99 | 
100 | **In summary**: While the functionality exists for comparison purposes, it’s not recommended for production. A high concentration setting yields a high PSNR but, paradoxically, more visible artifacts in the watermarked area.
101 | 

--------------------------------------------------------------------------------
/python/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Adobe
2 | All Rights Reserved.
3 | 
4 | NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | accordance with the terms of the license agreement accompanying it.
6 | 
7 | MIT License
8 | 
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 | 
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 | 
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 | 

--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark
2 | 
3 | This repository contains the official, open source implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2].
4 | 
5 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK.
6 | 
7 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA.
8 | 
9 | ## Overview
10 | 
11 | This repository contains the following directories:
12 | 
13 | - `/python`: Python implementation of TrustMark for encoding, decoding, and removing image watermarks (using PyTorch). For information on configuring TrustMark in Python, see [Configuring TrustMark](CONFIG.md).
14 | - `/js`: JavaScript implementation of TrustMark decoding of image watermarks (using ONNX). For more information, see [TrustMark - JavaScript implementation](../js/README.md).
15 | - `/c2pa`: Python example of how to indicate the presence of a TrustMark watermark in a C2PA manifest. For more information, see [Using TrustMark with C2PA](../c2pa/README.md).
16 | 
17 | Model files (**ckpt** PyTorch files for Python and **onnx** ONNX files for JavaScript) are not packaged in this repository due to their size, but are downloaded upon first use. If you need a direct download link, the code lists the [URLs and md5 hashes](https://github.com/adobe/trustmark/blob/4ef0dde4abd84d1c6873e7c5024482f849db2c73/python/trustmark/trustmark.py#L30) of each model.
18 | 
19 | More information:
20 | - For answers to common questions, see the [FAQ](../FAQ.md).
21 | - For information on configuring TrustMark in Python, see [Configuring TrustMark](CONFIG.md).
22 | 
23 | ## Installation
24 | 
25 | ### Prerequisite
26 | 
27 | You must have Python 3.8.5 or higher to use the TrustMark Python implementation.
28 | 
29 | ### Installing from PyPI
30 | 
31 | The easiest way to install TrustMark is from the [Python Package Index (PyPI)](https://pypi.org/project/trustmark/) by entering this command:
32 | 
33 | ```
34 | pip install trustmark
35 | ```
36 | 
37 | Alternatively, after you've cloned the repository, you can install from the `python` directory:
38 | 
39 | ```
40 | cd trustmark/python
41 | pip install .
42 | ```
43 | 
44 | ## Quickstart
45 | 
46 | To get started quickly, run the `python/test.py` script that provides examples of watermarking several
47 | image files from the `images` directory.
48 | 
49 | ### Run the example
50 | 
51 | Run the example as follows:
52 | 
53 | ```sh
54 | cd trustmark/python
55 | python test.py
56 | ```
57 | 
58 | You'll see output like this:
59 | 
60 | ```
61 | Initializing TrustMark watermarking with ECC using [cpu]
62 | Extracted secret: 1000000100001110000010010001011110010001011000100000100110110 (schema 1)
63 | PSNR = 50.357909
64 | No secret after removal
65 | ```
66 | 
67 | ### Example script
68 | 
69 | The `python/test.py` script provides examples of watermarking a JPEG photo, a JPEG GenAI image, and an RGBA PNG image. The following snippet uses TrustMark variant Q to encode the word `mysecret` (in ASCII7 encoding) into the image `ufo_240.jpg`; the watermark is then decoded and finally removed from the image.
70 | 
71 | ```python
72 | from trustmark import TrustMark
73 | from PIL import Image
74 | 
75 | # init
76 | tm=TrustMark(verbose=True, model_type='Q') # or try P
77 | 
78 | # encoding example
79 | cover = Image.open('images/ufo_240.jpg').convert('RGB')
80 | tm.encode(cover, 'mysecret').save('ufo_240_Q.png')
81 | 
82 | # decoding example
83 | cover = Image.open('images/ufo_240_Q.png').convert('RGB')
84 | wm_secret, wm_present, wm_schema = tm.decode(cover)
85 | 
86 | if wm_present:
87 |    print(f'Extracted secret: {wm_secret}')
88 | else:
89 |    print('No watermark detected')
90 | 
91 | # removal example
92 | stego = Image.open('images/ufo_240_Q.png').convert('RGB')
93 | im_recover = tm.remove_watermark(stego)
94 | im_recover.save('images/recovered.png')
95 | ```
96 | 
97 | ## GPU setup
98 | 
99 | TrustMark runs well on CPU hardware.
100 | 
101 | To leverage GPU compute for the PyTorch implementation on Ubuntu Linux, first install Conda, then use the following commands to install:
102 | 
103 | ```sh
104 | conda create --name trustmark python=3.10
105 | conda activate trustmark
106 | conda install pytorch cudatoolkit=12.8 -c pytorch -c conda-forge
107 | pip install torch==2.1.2 torchvision==0.16.2 -f https://download.pytorch.org/whl/torch_stable.html
108 | pip install .
109 | ```
110 | 
111 | For the JavaScript implementation, a Chromium browser automatically uses WebGPU if available.
112 | 
113 | ## Data schema
114 | 
115 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
116 | You can configure an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
117 | 
118 | In payload encoding, the version bits comprise two reserved (unused) bits, and two bits encoding an integer value 0-3 that specifies the data schema as follows:
119 | - 0: BCH_SUPER
120 | - 1: BCH_5
121 | - 2: BCH_4
122 | - 3: BCH_3
123 | 
124 | For more details and information on configuring the encoding mode in Python, see [Configuring TrustMark](CONFIG.md).
125 | 
126 | ## Citation
127 | 
128 | If you find this work useful, please cite the repository and/or the TrustMark paper as follows:
129 | 
130 | ```
131 | @article{trustmark,
132 |   title={Trustmark: Universal Watermarking for Arbitrary Resolution Images},
133 |   author={Bui, Tu and Agarwal, Shruti and Collomosse, John},
134 |   journal = {ArXiv e-prints},
135 |   archivePrefix = "arXiv",
136 |   eprint = {2311.18297},
137 |   year = 2023,
138 |   month = nov
139 | }
140 | ```
141 | 
142 | ## License
143 | 
144 | This package is distributed under the terms of the [MIT license](https://github.com/adobe/trustmark/blob/main/LICENSE).
145 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "trustmark" 3 | version = "0.8.0" 4 | description = "TrustMark: Universal Watermarking for Arbitrary Resolution Images" 5 | readme = "README.md" 6 | requires-python = ">=3.8.5" 7 | 8 | license = {file = "LICENSE.txt"} 9 | 10 | keywords = ["watermarking", "CAI", "watermark", "Content Authenticity Initiative", "provenance"] # Optional 11 | 12 | authors = [ 13 | {name = "Shruti Agarwal", email = "shragarw@adobe.com" }, 14 | {name = "John Collomosse", email = "collomos@adobe.com" } 15 | ] 16 | 17 | maintainers = [ 18 | {name = "John Collomosse", email = "collomos@adobe.com" } 19 | ] 20 | 21 | 22 | classifiers = [ 23 | "Development Status :: 5 - Production/Stable", 24 | 25 | # Indicate who your project is intended for 26 | "Intended Audience :: Science/Research", 27 | 28 | "License :: OSI Approved :: MIT License", 29 | 30 | "Programming Language :: Python :: 3" 31 | ] 32 | 33 | dependencies = [ 34 | "omegaconf>=2.1", 35 | "pathlib>=1.0.1", 36 | "numpy>=1.20.0,<2.0.0", 37 | "torch>=2.1.2", 38 | "torchvision>=0.16.2", 39 | "lightning>=2.0", 40 | "six>=1.9", 41 | "einops>=0.4.0" 42 | ] 43 | 44 | [build-system] 45 | requires = ["setuptools>=43.0.0", "wheel"] 46 | build-backend = "setuptools.build_meta" 47 | 48 | [project.urls] 49 | Repository = "https://github.com/adobe/trustmark.git" 50 | 51 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.20.0,<2.0.0 2 | pathlib>=1.0.1 3 | torch>=2.1.2 4 | torchvision>=0.16.2 5 | lightning>=2.0 6 | omegaconf>=2.1 7 | six>=1.9 8 | einops>=0.4.0 9 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 
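# Packaging script for the trustmark PyPI distribution. The metadata and
# dependency pins below mirror those declared in pyproject.toml above.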
7 | 
8 | from setuptools import setup, find_packages
9 | 
10 | with open("README.md", "r") as f:
11 |     long_description = f.read()
12 | 
13 | setup(name='trustmark',
14 |       version='0.8.0',
15 |       python_requires='>=3.8.5',
16 |       description='High fidelity image watermarking for the Content Authenticity Initiative (CAI)',
17 |       url='https://github.com/adobe/trustmark',
18 |       long_description=long_description,
19 |       long_description_content_type="text/markdown",
20 |       author='John Collomosse',
21 |       author_email='collomos@adobe.com',
22 |       license='MIT License',
23 |       packages=['trustmark','trustmark.KBNet'],
24 |       package_data={'trustmark': ['**/*.yaml','**/*.ckpt','**/*.md']},
25 |       include_package_data = True,
26 |       install_requires=['omegaconf>=2.1',
27 |                         'pathlib>=1.0.1',
28 |                         'numpy>=1.20.0,<2.0.0',
29 |                         'torch>=2.1.2',
30 |                         'torchvision>=0.16.2',
31 |                         'lightning>=2.0',
32 |                         'six>=1.9',
33 |                         'einops>=0.4.0'
34 |                         ],
35 | 
36 |       classifiers=[
37 |           'Development Status :: 5 - Production/Stable',
38 |           'Intended Audience :: Science/Research',
39 |           'Operating System :: OS Independent',
40 |           'License :: OSI Approved :: MIT License',
41 |           'Programming Language :: Python :: 3',],)
42 | 
-------------------------------------------------------------------------------- /python/test.py: --------------------------------------------------------------------------------
1 | # Copyright 2025 Adobe
2 | # All Rights Reserved.
3 | 
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 | 
8 | 
9 | from trustmark import TrustMark
10 | from PIL import Image
11 | from pathlib import Path
12 | import math, random
13 | import numpy as np
14 | 
15 | 
16 | #EXAMPLE_FILE = '../images/ufo_240.jpg'   # JPEG example
17 | #EXAMPLE_FILE = '../images/ghost.png'     # PNG RGBA example
18 | EXAMPLE_FILE = '../images/ripley.jpg'     # JPEG example
19 | 
20 | # Available modes: Q=balance, P=high visual quality, C=compact decoder, B=base from paper
21 | MODE='P'
22 | tm=TrustMark(verbose=True, model_type=MODE, encoding_type=TrustMark.Encoding.BCH_5)
23 | 
24 | # encoding example
25 | cover = Image.open(EXAMPLE_FILE)
26 | rgb = cover.convert('RGB')
27 | has_alpha = cover.mode == 'RGBA'
28 | if (has_alpha):
29 |     alpha = cover.split()[-1]
30 | 
31 | random.seed(1234)
32 | capacity = tm.schemaCapacity()
33 | bitstring = ''.join([random.choice(['0', '1']) for _ in range(capacity)])
34 | encoded = tm.encode(rgb, bitstring, MODE='binary')
35 | 
36 | if (has_alpha):
37 |     encoded.putalpha(alpha)
38 | outfile = Path(EXAMPLE_FILE).stem+'_'+MODE+'.png'
39 | encoded.save(outfile, exif=cover.info.get('exif'), icc_profile=cover.info.get('icc_profile'), dpi=cover.info.get('dpi'))
40 | 
41 | # decoding example
42 | stego = Image.open(outfile).convert('RGB')
43 | wm_secret, wm_present, wm_schema = tm.decode(stego, MODE='binary')
44 | if wm_present:
45 |     print(f'Extracted secret: {wm_secret} (schema {wm_schema})')
46 | else:
47 |     print('No watermark detected')
48 | 
49 | # psnr (quality, higher is better)
50 | mse = np.mean(np.square(np.subtract(np.asarray(stego).astype(np.int16), np.asarray(rgb).astype(np.int16))))
51 | if mse > 0:
52 |     PIXEL_MAX = 255.0
53 |     psnr = 20 * math.log10(PIXEL_MAX) - 10 * math.log10(mse)
54 |     print('PSNR = %f' % psnr)
55 | 
56 | # removal
57 | stego = Image.open(outfile)
58 | rgb = stego.convert('RGB')
59 | has_alpha = stego.mode == 'RGBA'
60 | if (has_alpha):
61 |     alpha = stego.split()[-1]
62 | im_recover = tm.remove_watermark(rgb)
63 | wm_secret, wm_present, wm_schema = 
tm.decode(im_recover) 64 | if wm_present: 65 | print(f'Extracted secret: {wm_secret} (schema {wm_schema})') 66 | else: 67 | print('No secret after removal') 68 | if (has_alpha): 69 | im_recover.putalpha(alpha) 70 | im_recover.save('recovered.png', exif=stego.info.get('exif'), icc_profile=stego.info.get('icc_profile'), dpi=stego.info.get('dpi')) 71 | 72 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/README.md: -------------------------------------------------------------------------------- 1 | Adapted from https://github.com/zhangyi-3/KBNet 2 | 3 | MIT License 4 | 5 | Copyright (c) 2022 Zhang Yi 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/arch_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 7 | 8 | import math 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | from torch.nn import init as init 13 | from torch.nn.modules.batchnorm import _BatchNorm 14 | 15 | 16 | @torch.no_grad() 17 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 18 | """Initialize network weights. 19 | 20 | Args: 21 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 22 | scale (float): Scale initialized weights, especially for residual 23 | blocks. Default: 1. 24 | bias_fill (float): The value to fill bias. Default: 0 25 | kwargs (dict): Other arguments for initialization function. 
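 
        Example (illustrative sketch, not in the original source):
            >>> conv = nn.Conv2d(3, 64, 3)
            >>> default_init_weights(conv, scale=0.1)  # Kaiming init, weights then scaled by 0.1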
26 | """ 27 | if not isinstance(module_list, list): 28 | module_list = [module_list] 29 | for module in module_list: 30 | for m in module.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | init.kaiming_normal_(m.weight, **kwargs) 33 | m.weight.data *= scale 34 | if m.bias is not None: 35 | m.bias.data.fill_(bias_fill) 36 | elif isinstance(m, nn.Linear): 37 | init.kaiming_normal_(m.weight, **kwargs) 38 | m.weight.data *= scale 39 | if m.bias is not None: 40 | m.bias.data.fill_(bias_fill) 41 | elif isinstance(m, _BatchNorm): 42 | init.constant_(m.weight, 1) 43 | if m.bias is not None: 44 | m.bias.data.fill_(bias_fill) 45 | 46 | 47 | def make_layer(basic_block, num_basic_block, **kwarg): 48 | """Make layers by stacking the same blocks. 49 | 50 | Args: 51 | basic_block (nn.module): nn.module class for basic block. 52 | num_basic_block (int): number of blocks. 53 | 54 | Returns: 55 | nn.Sequential: Stacked blocks in nn.Sequential. 56 | """ 57 | layers = [] 58 | for _ in range(num_basic_block): 59 | layers.append(basic_block(**kwarg)) 60 | return nn.Sequential(*layers) 61 | 62 | 63 | class ResidualBlockNoBN(nn.Module): 64 | """Residual block without BN. 65 | 66 | It has a style of: 67 | ---Conv-ReLU-Conv-+- 68 | |________________| 69 | 70 | Args: 71 | num_feat (int): Channel number of intermediate features. 72 | Default: 64. 73 | res_scale (float): Residual scale. Default: 1. 74 | pytorch_init (bool): If set to True, use pytorch default init, 75 | otherwise, use default_init_weights. Default: False. 76 | """ 77 | 78 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 79 | super(ResidualBlockNoBN, self).__init__() 80 | self.res_scale = res_scale 81 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 82 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 83 | self.relu = nn.ReLU(inplace=True) 84 | 85 | if not pytorch_init: 86 | default_init_weights([self.conv1, self.conv2], 0.1) 87 | 88 | def forward(self, x): 89 | identity = x 90 | out = self.conv2(self.relu(self.conv1(x))) 91 | return identity + out * self.res_scale 92 | 93 | 94 | class Upsample(nn.Sequential): 95 | """Upsample module. 96 | 97 | Args: 98 | scale (int): Scale factor. Supported scales: 2^n and 3. 99 | num_feat (int): Channel number of intermediate features. 100 | """ 101 | 102 | def __init__(self, scale, num_feat): 103 | m = [] 104 | if (scale & (scale - 1)) == 0: # scale = 2^n 105 | for _ in range(int(math.log(scale, 2))): 106 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 107 | m.append(nn.PixelShuffle(2)) 108 | elif scale == 3: 109 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 110 | m.append(nn.PixelShuffle(3)) 111 | else: 112 | raise ValueError(f'scale {scale} is not supported. ' 113 | 'Supported scales: 2^n and 3.') 114 | super(Upsample, self).__init__(*m) 115 | 116 | 117 | def flow_warp(x, 118 | flow, 119 | interp_mode='bilinear', 120 | padding_mode='zeros', 121 | align_corners=True): 122 | """Warp an image or feature map with optical flow. 123 | 124 | Args: 125 | x (Tensor): Tensor with size (n, c, h, w). 126 | flow (Tensor): Tensor with size (n, h, w, 2), normal value. 127 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. 128 | padding_mode (str): 'zeros' or 'border' or 'reflection'. 129 | Default: 'zeros'. 130 | align_corners (bool): Before pytorch 1.3, the default value is 131 | align_corners=True. After pytorch 1.3, the default value is 132 | align_corners=False. Here, we use the True as default. 
133 | 134 | Returns: 135 | Tensor: Warped image or feature map. 136 | """ 137 | assert x.size()[-2:] == flow.size()[1:3] 138 | _, _, h, w = x.size() 139 | # create mesh grid 140 | grid_y, grid_x = torch.meshgrid( 141 | torch.arange(0, h).type_as(x), 142 | torch.arange(0, w).type_as(x)) 143 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 144 | grid.requires_grad = False 145 | 146 | vgrid = grid + flow 147 | # scale grid to [-1,1] 148 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 149 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 150 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 151 | output = F.grid_sample( 152 | x, 153 | vgrid_scaled, 154 | mode=interp_mode, 155 | padding_mode=padding_mode, 156 | align_corners=align_corners) 157 | 158 | # TODO, what if align_corners=False 159 | return output 160 | 161 | 162 | def resize_flow(flow, 163 | size_type, 164 | sizes, 165 | interp_mode='bilinear', 166 | align_corners=False): 167 | """Resize a flow according to ratio or shape. 168 | 169 | Args: 170 | flow (Tensor): Precomputed flow. shape [N, 2, H, W]. 171 | size_type (str): 'ratio' or 'shape'. 172 | sizes (list[int | float]): the ratio for resizing or the final output 173 | shape. 174 | 1) The order of ratio should be [ratio_h, ratio_w]. For 175 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio 176 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., 177 | ratio > 1.0). 178 | 2) The order of output_size should be [out_h, out_w]. 179 | interp_mode (str): The mode of interpolation for resizing. 180 | Default: 'bilinear'. 181 | align_corners (bool): Whether align corners. Default: False. 182 | 183 | Returns: 184 | Tensor: Resized flow. 185 | """ 186 | _, _, flow_h, flow_w = flow.size() 187 | if size_type == 'ratio': 188 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) 189 | elif size_type == 'shape': 190 | output_h, output_w = sizes[0], sizes[1] 191 | else: 192 | raise ValueError( 193 | f'Size type should be ratio or shape, but got type {size_type}.') 194 | 195 | input_flow = flow.clone() 196 | ratio_h = output_h / flow_h 197 | ratio_w = output_w / flow_w 198 | input_flow[:, 0, :, :] *= ratio_w 199 | input_flow[:, 1, :, :] *= ratio_h 200 | resized_flow = F.interpolate( 201 | input=input_flow, 202 | size=(output_h, output_w), 203 | mode=interp_mode, 204 | align_corners=align_corners) 205 | return resized_flow 206 | 207 | 208 | # TODO: may write a cpp file 209 | def pixel_unshuffle(x, scale): 210 | """ Pixel unshuffle. 211 | 212 | Args: 213 | x (Tensor): Input feature with shape (b, c, hh, hw). 214 | scale (int): Downsample ratio. 215 | 216 | Returns: 217 | Tensor: the pixel unshuffled feature. 218 | """ 219 | b, c, hh, hw = x.size() 220 | out_channel = c * (scale ** 2) 221 | assert hh % scale == 0 and hw % scale == 0 222 | h = hh // scale 223 | w = hw // scale 224 | x_view = x.view(b, c, h, scale, w, scale) 225 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 226 | 227 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/kb_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 
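# Added note: this module provides the building blocks shared by the KBNet
# architectures below: LayerNorm2d (a per-pixel channel normalization with a
# hand-written backward pass), the SimpleGate activation (splits the channels
# in two and multiplies the halves), and KBAFunction, a custom autograd
# Function implementing the Kernel Basis Attention (KBA) operator.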
7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class LayerNormFunction(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, x, weight, bias, eps): 16 | ctx.eps = eps 17 | N, C, H, W = x.size() 18 | mu = x.mean(1, keepdim=True) 19 | var = (x - mu).pow(2).mean(1, keepdim=True) 20 | # print('mu, var', mu.mean(), var.mean()) 21 | # d.append([mu.mean(), var.mean()]) 22 | y = (x - mu) / (var + eps).sqrt() 23 | weight, bias, y = weight.contiguous(), bias.contiguous(), y.contiguous() # avoid cuda error 24 | ctx.save_for_backward(y, var, weight) 25 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 26 | return y 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | eps = ctx.eps 31 | 32 | N, C, H, W = grad_output.size() 33 | # y, var, weight = ctx.saved_variables 34 | y, var, weight = ctx.saved_tensors 35 | g = grad_output * weight.view(1, C, 1, 1) 36 | mean_g = g.mean(dim=1, keepdim=True) 37 | 38 | mean_gy = (g * y).mean(dim=1, keepdim=True) 39 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 40 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 41 | dim=0), None 42 | 43 | 44 | class LayerNorm2d(nn.Module): 45 | 46 | def __init__(self, channels, eps=1e-6, requires_grad=True): 47 | super(LayerNorm2d, self).__init__() 48 | self.register_parameter('weight', nn.Parameter(torch.ones(channels), requires_grad=requires_grad)) 49 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels), requires_grad=requires_grad)) 50 | self.eps = eps 51 | 52 | def forward(self, x): 53 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 54 | 55 | 56 | class SimpleGate(nn.Module): 57 | def forward(self, x): 58 | x1, x2 = x.chunk(2, dim=1) 59 | return x1 * x2 60 | 61 | 62 | class KBAFunction(torch.autograd.Function): 63 | 64 | @staticmethod 65 | def forward(ctx, x, att, selfk, selfg, selfb, selfw): 66 | B, nset, H, W = att.shape 67 | KK = selfk ** 2 68 | selfc = x.shape[1] 69 | 70 | att = att.reshape(B, nset, H * W).transpose(-2, -1) 71 | 72 | ctx.selfk, ctx.selfg, ctx.selfc, ctx.KK, ctx.nset = selfk, selfg, selfc, KK, nset 73 | ctx.x, ctx.att, ctx.selfb, ctx.selfw = x, att, selfb, selfw 74 | 75 | bias = att @ selfb 76 | attk = att @ selfw 77 | 78 | uf = torch.nn.functional.unfold(x, kernel_size=selfk, padding=selfk // 2) 79 | 80 | # for unfold att / less memory cost 81 | uf = uf.reshape(B, selfg, selfc // selfg * KK, H * W).permute(0, 3, 1, 2) 82 | attk = attk.reshape(B, H * W, selfg, selfc // selfg, selfc // selfg * KK) 83 | 84 | x = attk @ uf.unsqueeze(-1) # 85 | del attk, uf 86 | x = x.squeeze(-1).reshape(B, H * W, selfc) + bias 87 | x = x.transpose(-1, -2).reshape(B, selfc, H, W) 88 | return x 89 | 90 | @staticmethod 91 | def backward(ctx, grad_output): 92 | x, att, selfb, selfw = ctx.x, ctx.att, ctx.selfb, ctx.selfw 93 | selfk, selfg, selfc, KK, nset = ctx.selfk, ctx.selfg, ctx.selfc, ctx.KK, ctx.nset 94 | 95 | B, selfc, H, W = grad_output.size() 96 | 97 | dbias = grad_output.reshape(B, selfc, H * W).transpose(-1, -2) 98 | 99 | dselfb = att.transpose(-2, -1) @ dbias 100 | datt = dbias @ selfb.transpose(-2, -1) 101 | 102 | attk = att @ selfw 103 | uf = F.unfold(x, kernel_size=selfk, padding=selfk // 2) 104 | # for unfold att / less memory cost 105 | uf = uf.reshape(B, selfg, selfc // selfg * KK, H * W).permute(0, 3, 1, 2) 106 | attk = attk.reshape(B, H * W, selfg, selfc // selfg, selfc // selfg * KK) 107 | 108 | dx = dbias.view(B, H * W, 
selfg, selfc // selfg, 1) 109 | 110 | dattk = dx @ uf.view(B, H * W, selfg, 1, selfc // selfg * KK) 111 | duf = attk.transpose(-2, -1) @ dx 112 | del attk, uf 113 | 114 | dattk = dattk.view(B, H * W, -1) 115 | datt += dattk @ selfw.transpose(-2, -1) 116 | dselfw = att.transpose(-2, -1) @ dattk 117 | 118 | duf = duf.permute(0, 2, 3, 4, 1).view(B, -1, H * W) 119 | dx = F.fold(duf, output_size=(H, W), kernel_size=selfk, padding=selfk // 2) 120 | 121 | datt = datt.transpose(-1, -2).view(B, nset, H, W) 122 | 123 | return dx, datt, None, None, dselfb, dselfw 124 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/kbnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 7 | 8 | import math 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as thf 12 | try: 13 | import lightning as pl 14 | except ImportError: 15 | import pytorch_lightning as pl 16 | import einops 17 | import kornia 18 | import numpy as np 19 | import torchvision 20 | import importlib 21 | from torchmetrics.functional import peak_signal_noise_ratio 22 | from contextlib import contextmanager 23 | from omegaconf import OmegaConf 24 | 25 | 26 | class WMRemoverKBNet(pl.LightningModule): 27 | def __init__(self, 28 | cover_key, 29 | secret_key, 30 | secret_embedder_config_path, 31 | secret_embedder_ckpt_path, 32 | denoise_config, 33 | grad_clip, 34 | ckpt_path="__none__", 35 | ): 36 | super().__init__() 37 | self.automatic_optimization = False # for GAN training 38 | self.cover_key = cover_key 39 | self.secret_key = secret_key 40 | self.grad_clip = grad_clip 41 | 42 | secret_embedder_config = OmegaConf.load(secret_embedder_config_path).model 43 | secret_embedder_config.params.ckpt_path = secret_embedder_ckpt_path 44 | self.secret_len = secret_embedder_config.params.secret_len 45 | 46 | self.secret_embedder = instantiate_from_config(secret_embedder_config).eval() 47 | for p in self.secret_embedder.parameters(): 48 | p.requires_grad = False 49 | 50 | self.denoise = instantiate_from_config(denoise_config) 51 | if ckpt_path != "__none__": 52 | self.init_from_ckpt(ckpt_path, ignore_keys=[]) 53 | 54 | def init_from_ckpt(self, path, ignore_keys=[]): 55 | sd = torch.load(path, map_location="cpu")["state_dict"] 56 | keys = list(sd.keys()) 57 | for k in keys: 58 | for ik in ignore_keys: 59 | if k.startswith(ik): 60 | print("Deleting key {} from state_dict.".format(k)) 61 | del sd[k] 62 | self.load_state_dict(sd, strict=False) 63 | print(f"Restored from {path}") 64 | 65 | @torch.no_grad() 66 | def get_input(self, batch, bs=None): 67 | image = batch[self.cover_key] 68 | secret = batch[self.secret_key] 69 | if bs is not None: 70 | image = image[:bs] 71 | secret = secret[:bs] 72 | else: 73 | bs = image.shape[0] 74 | # encode image 1st stage 75 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous() 76 | stego = self.secret_embedder(image, secret)[0] 77 | out = [stego, image, secret] 78 | return out 79 | 80 | def forward(self, x): 81 | return torch.clamp(self.denoise(x), -1, 1) 82 | 83 | def shared_step(self, batch, batch_idx): 84 | is_training = self.training 85 | x, y, s = self.get_input(batch) 86 | if is_training: 87 | opt_g = self.optimizers() 88 | x_denoised = self(x) 89 | loss = torch.abs(x_denoised - 
y).mean() 90 | s_pred = self.secret_embedder.decoder(x_denoised) 91 | loss_dict = {} 92 | loss_dict['total_loss'] = loss 93 | loss_dict['bit_acc'] = ((torch.sigmoid(s_pred.detach()) > 0.5).float() == s).float().mean() 94 | if is_training: 95 | self.manual_backward(loss) 96 | if self.grad_clip: 97 | # torch.nn.utils.clip_grad_norm_(self.denoise.parameters(), 0.01) 98 | self.clip_gradients(opt_g, gradient_clip_val=0.01, gradient_clip_algorithm="norm") 99 | opt_g.step() 100 | opt_g.zero_grad() 101 | 102 | loss_dict['psnr_denoise'] = peak_signal_noise_ratio(x_denoised.detach(), y.detach(), data_range=2.0) 103 | loss_dict['psnr_stego'] = peak_signal_noise_ratio(x.detach(), y.detach(), data_range=2.0) 104 | 105 | return loss, loss_dict 106 | 107 | def training_step(self, batch, batch_idx): 108 | loss, loss_dict = self.shared_step(batch, batch_idx) 109 | # logging 110 | loss_dict = {f"train/{key}": val for key, val in loss_dict.items()} 111 | self.log_dict(loss_dict, prog_bar=True, 112 | logger=True, on_step=True, on_epoch=True) 113 | 114 | self.log("global_step", float(self.global_step), 115 | prog_bar=True, logger=True, on_step=True, on_epoch=False) 116 | # if self.use_scheduler: 117 | # lr = self.optimizers().param_groups[0]['lr'] 118 | # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) 119 | sch = self.lr_schedulers() 120 | self.log('lr_abs', sch.get_lr()[0], prog_bar=True, logger=True, on_step=True, on_epoch=False) 121 | sch.step() 122 | 123 | # return loss 124 | 125 | @torch.no_grad() 126 | def validation_step(self, batch, batch_idx): 127 | _, loss_dict_no_ema = self.shared_step(batch, batch_idx) 128 | loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'} 129 | self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) 130 | 131 | @torch.no_grad() 132 | def log_images(self, batch, fixed_input=False, **kwargs): 133 | log = dict() 134 | x, y, s = self.get_input(batch) 135 | x_denoised = self(x) 136 | log['clean'] = y 137 | log['stego'] = x 138 | log['denoised'] = x_denoised 139 | residual = x_denoised - y 140 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1 141 | return log 142 | 143 | def configure_optimizers(self): 144 | lr = self.learning_rate 145 | params_g = list(self.denoise.parameters()) 146 | opt_g = torch.optim.AdamW(params_g, lr=lr, weight_decay=1e-4, betas=(0.9, 0.999)) 147 | lr_sch = lr_scheduler.CosineAnnealingRestartCyclicLR( 148 | opt_g, periods=[92000, 208000], restart_weights= [1,1], eta_mins=[0.0003,0.000001]) 149 | return [opt_g], [lr_sch] 150 | 151 | 152 | 153 | def get_obj_from_str(string, reload=False): 154 | module, cls = string.rsplit(".", 1) 155 | if reload: 156 | module_imp = importlib.import_module(module) 157 | importlib.reload(module_imp) 158 | return getattr(importlib.import_module(module, package=None), cls) 159 | 160 | 161 | 162 | def instantiate_from_config(config): 163 | if not "target" in config: 164 | if config == '__is_first_stage__': 165 | return None 166 | elif config == "__is_unconditional__": 167 | return None 168 | raise KeyError("Expected key `target` to instantiate.") 169 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 170 | 171 | 172 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/kbnet_l_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # 
All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.init as init 14 | 15 | from einops import rearrange 16 | 17 | from .kb_utils import KBAFunction 18 | from .kb_utils import LayerNorm2d, SimpleGate 19 | 20 | 21 | class Downsample(nn.Module): 22 | def __init__(self, n_feat): 23 | super(Downsample, self).__init__() 24 | 25 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), 26 | nn.PixelUnshuffle(2)) 27 | 28 | def forward(self, x): 29 | return self.body(x) 30 | 31 | 32 | class Upsample(nn.Module): 33 | def __init__(self, n_feat): 34 | super(Upsample, self).__init__() 35 | 36 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), 37 | nn.PixelShuffle(2)) 38 | 39 | def forward(self, x): 40 | return self.body(x) 41 | 42 | 43 | class OverlapPatchEmbed(nn.Module): 44 | def __init__(self, in_c=3, embed_dim=48, bias=False): 45 | super(OverlapPatchEmbed, self).__init__() 46 | 47 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) 48 | 49 | def forward(self, x): 50 | x = self.proj(x) 51 | return x 52 | 53 | 54 | class TransAttention(nn.Module): 55 | def __init__(self, dim, num_heads, bias): 56 | super(TransAttention, self).__init__() 57 | self.num_heads = num_heads 58 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 59 | 60 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) 61 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) 62 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 63 | 64 | def forward(self, x): 65 | b, c, h, w = x.shape 66 | 67 | qkv = self.qkv_dwconv(self.qkv(x)) 68 | q, k, v = qkv.chunk(3, dim=1) 69 | 70 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 71 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 72 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 73 | 74 | q = torch.nn.functional.normalize(q, dim=-1) 75 | k = torch.nn.functional.normalize(k, dim=-1) 76 | 77 | attn = (q @ k.transpose(-2, -1)) * self.temperature 78 | attn = attn.softmax(dim=-1) 79 | 80 | out = (attn @ v) 81 | 82 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 83 | 84 | out = self.project_out(out) 85 | return out 86 | 87 | 88 | class MFF(nn.Module): 89 | def __init__(self, dim, ffn_expansion_factor, bias, act=True, gc=2, nset=32, k=3): 90 | super(MFF, self).__init__() 91 | self.act = act 92 | self.gc = gc 93 | 94 | hidden_features = int(dim * ffn_expansion_factor) 95 | 96 | self.dwconv = nn.Sequential( 97 | nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias), 98 | nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, 99 | groups=hidden_features, bias=bias), 100 | ) 101 | 102 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 103 | 104 | self.sca = nn.Sequential( 105 | nn.AdaptiveAvgPool2d(1), 106 | nn.Conv2d(in_channels=dim, out_channels=hidden_features, kernel_size=1, padding=0, stride=1, 107 | groups=1, bias=True), 108 | ) 109 | self.conv1 = nn.Sequential( 110 | nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias), 111 | 
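                # 3x3 depthwise conv (groups=hidden_features) paired with the 1x1 pointwise conv above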
nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, 112 | groups=hidden_features, bias=bias), 113 | ) 114 | 115 | c = hidden_features 116 | self.k, self.c = k, c 117 | self.nset = nset 118 | 119 | self.g = c // gc 120 | self.w = nn.Parameter(torch.zeros(1, nset, c * c // self.g * self.k ** 2)) 121 | self.b = nn.Parameter(torch.zeros(1, nset, c)) 122 | self.init_p(self.w, self.b) 123 | interc = min(dim, 24) 124 | # print(c, interc) 125 | self.conv2 = nn.Sequential( 126 | nn.Conv2d(in_channels=dim, out_channels=interc, kernel_size=3, padding=1, stride=1, groups=interc, 127 | bias=True), 128 | SimpleGate(), 129 | nn.Conv2d(interc // 2, self.nset, 1, padding=0, stride=1), 130 | ) 131 | self.conv211 = nn.Conv2d(in_channels=dim, out_channels=self.nset, kernel_size=1) 132 | self.attgamma = nn.Parameter(torch.zeros((1, self.nset, 1, 1)) + 1e-2, requires_grad=True) 133 | self.ga1 = nn.Parameter(torch.zeros((1, hidden_features, 1, 1)) + 1e-2, requires_grad=True) 134 | 135 | def forward(self, x): 136 | sca = self.sca(x) 137 | x1 = self.dwconv(x) 138 | 139 | att = self.conv2(x) * self.attgamma + self.conv211(x) 140 | uf = self.conv1(x) 141 | x2 = self.KBA(uf, att, self.k, self.g, self.b, self.w) * self.ga1 + uf 142 | 143 | x = F.gelu(x1) * x2 if self.act else x1 * x2 144 | x = x * sca 145 | x = self.project_out(x) 146 | return x 147 | 148 | def init_p(self, weight, bias=None): 149 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 150 | if bias is not None: 151 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 152 | bound = 1 / math.sqrt(fan_in) 153 | init.uniform_(bias, -bound, bound) 154 | 155 | def KBA(self, x, att, selfk, selfg, selfb, selfw): 156 | return KBAFunction.apply(x, att, selfk, selfg, selfb, selfw) 157 | 158 | 159 | class KBBlock_l(nn.Module): 160 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias): 161 | super(KBBlock_l, self).__init__() 162 | 163 | self.norm1 = LayerNorm2d(dim) 164 | self.norm2 = LayerNorm2d(dim) 165 | 166 | self.attn = MFF(dim, ffn_expansion_factor, bias) 167 | self.ffn = TransAttention(dim, num_heads, bias) 168 | 169 | def forward(self, x): 170 | x = x + self.attn(self.norm1(x)) 171 | x = x + self.ffn(self.norm2(x)) 172 | return x 173 | 174 | 175 | class KBNet_l(nn.Module): 176 | def __init__(self, inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, 177 | heads=[1, 2, 4, 8], ffn_expansion_factor=1.5, bias=False, 178 | blockname='KBBlock_l'): 179 | super(KBNet_l, self).__init__() 180 | # print('\r** ', blockname, end='') 181 | TransformerBlock = eval(blockname) 182 | 183 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 184 | 185 | self.encoder_level1 = nn.Sequential(*[ 186 | TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias) for i in 187 | range(num_blocks[0])]) 188 | 189 | self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 190 | self.encoder_level2 = nn.Sequential(*[ 191 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, 192 | bias=bias) for i in range(num_blocks[1])]) 193 | 194 | self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 195 | self.encoder_level3 = nn.Sequential(*[ 196 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, 197 | bias=bias) for i in range(num_blocks[2])]) 198 | 199 | self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 200 | self.latent = 
nn.Sequential(*[ 201 | TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, 202 | bias=bias) for i in range(num_blocks[3])]) 203 | 204 | self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 205 | self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) 206 | self.decoder_level3 = nn.Sequential(*[ 207 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, 208 | bias=bias) for i in range(num_blocks[2])]) 209 | 210 | self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 211 | self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) 212 | self.decoder_level2 = nn.Sequential(*[ 213 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, 214 | bias=bias) for i in range(num_blocks[1])]) 215 | 216 | self.up2_1 = Upsample(int(dim * 2 ** 1)) 217 | 218 | self.decoder_level1 = nn.Sequential(*[ 219 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, 220 | bias=bias) for i in range(num_blocks[0])]) 221 | 222 | self.refinement = nn.Sequential(*[ 223 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, 224 | bias=bias) for i in range(num_refinement_blocks)]) 225 | 226 | self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 227 | 228 | def forward(self, inp_img): 229 | inp_enc_level1 = self.patch_embed(inp_img) 230 | out_enc_level1 = self.encoder_level1(inp_enc_level1) 231 | 232 | inp_enc_level2 = self.down1_2(out_enc_level1) 233 | out_enc_level2 = self.encoder_level2(inp_enc_level2) 234 | 235 | inp_enc_level3 = self.down2_3(out_enc_level2) 236 | out_enc_level3 = self.encoder_level3(inp_enc_level3) 237 | 238 | inp_enc_level4 = self.down3_4(out_enc_level3) 239 | latent = self.latent(inp_enc_level4) 240 | 241 | inp_dec_level3 = self.up4_3(latent) 242 | inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) 243 | inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) 244 | out_dec_level3 = self.decoder_level3(inp_dec_level3) 245 | 246 | inp_dec_level2 = self.up3_2(out_dec_level3) 247 | inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) 248 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) 249 | out_dec_level2 = self.decoder_level2(inp_dec_level2) 250 | 251 | inp_dec_level1 = self.up2_1(out_dec_level2) 252 | inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) 253 | out_dec_level1 = self.decoder_level1(inp_dec_level1) 254 | 255 | out_dec_level1 = self.refinement(out_dec_level1) 256 | 257 | out_dec_level1 = self.output(out_dec_level1) + inp_img 258 | 259 | return out_dec_level1 260 | -------------------------------------------------------------------------------- /python/trustmark/KBNet/kbnet_s_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 
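# Added note: KBNet-S ("small") architecture. KBBlock_s combines NAFNet-style
# components (LayerNorm2d, SimpleGate, simplified channel attention) with the
# learnable Kernel Basis Attention (KBA) aggregation defined in kb_utils.py.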
7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.init as init 14 | 15 | from .kb_utils import KBAFunction 16 | from .kb_utils import LayerNorm2d, SimpleGate 17 | 18 | 19 | class KBBlock_s(nn.Module): 20 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, nset=32, k=3, gc=4, lightweight=False): 21 | super(KBBlock_s, self).__init__() 22 | self.k, self.c = k, c 23 | self.nset = nset 24 | dw_ch = int(c * DW_Expand) 25 | ffn_ch = int(FFN_Expand * c) 26 | 27 | self.g = c // gc 28 | self.w = nn.Parameter(torch.zeros(1, nset, c * c // self.g * self.k ** 2)) 29 | self.b = nn.Parameter(torch.zeros(1, nset, c)) 30 | self.init_p(self.w, self.b) 31 | 32 | self.norm1 = LayerNorm2d(c) 33 | self.norm2 = LayerNorm2d(c) 34 | 35 | self.sca = nn.Sequential( 36 | nn.AdaptiveAvgPool2d(1), 37 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, 38 | groups=1, bias=True), 39 | ) 40 | 41 | if not lightweight: 42 | self.conv11 = nn.Sequential( 43 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, 44 | bias=True), 45 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=5, padding=2, stride=1, groups=c // 4, 46 | bias=True), 47 | ) 48 | else: 49 | self.conv11 = nn.Sequential( 50 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, 51 | bias=True), 52 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, padding=1, stride=1, groups=c, 53 | bias=True), 54 | ) 55 | 56 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, 57 | bias=True) 58 | self.conv21 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, padding=1, stride=1, groups=c, 59 | bias=True) 60 | 61 | interc = min(c, 32) 62 | self.conv2 = nn.Sequential( 63 | nn.Conv2d(in_channels=c, out_channels=interc, kernel_size=3, padding=1, stride=1, groups=interc, 64 | bias=True), 65 | SimpleGate(), 66 | nn.Conv2d(interc // 2, self.nset, 1, padding=0, stride=1), 67 | ) 68 | 69 | self.conv211 = nn.Conv2d(in_channels=c, out_channels=self.nset, kernel_size=1) 70 | 71 | self.conv3 = nn.Conv2d(in_channels=dw_ch // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 72 | groups=1, bias=True) 73 | 74 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_ch, kernel_size=1, padding=0, stride=1, groups=1, 75 | bias=True) 76 | self.conv5 = nn.Conv2d(in_channels=ffn_ch // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 77 | groups=1, bias=True) 78 | 79 | self.dropout1 = nn.Identity() 80 | self.dropout2 = nn.Identity() 81 | 82 | self.ga1 = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True) 83 | self.attgamma = nn.Parameter(torch.zeros((1, self.nset, 1, 1)) + 1e-2, requires_grad=True) 84 | self.sg = SimpleGate() 85 | 86 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True) 87 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True) 88 | 89 | def init_p(self, weight, bias=None): 90 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 91 | if bias is not None: 92 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 93 | bound = 1 / math.sqrt(fan_in) 94 | init.uniform_(bias, -bound, bound) 95 | 96 | def KBA(self, x, att, selfk, selfg, selfb, selfw): 97 | return KBAFunction.apply(x, att, selfk, selfg, selfb, selfw) 98 | 99 | def forward(self, inp): 100 | x = inp 101 | 102 | x = self.norm1(x) 103 | sca = self.sca(x) 104 | x1 = self.conv11(x) 105 | 106 | # KBA module 107 | att = 
self.conv2(x) * self.attgamma + self.conv211(x) 108 | uf = self.conv21(self.conv1(x)) 109 | x = self.KBA(uf, att, self.k, self.g, self.b, self.w) * self.ga1 + uf 110 | x = x * x1 * sca 111 | 112 | x = self.conv3(x) 113 | x = self.dropout1(x) 114 | y = inp + x * self.beta 115 | 116 | # FFN 117 | x = self.norm2(y) 118 | x = self.conv4(x) 119 | x = self.sg(x) 120 | x = self.conv5(x) 121 | 122 | x = self.dropout2(x) 123 | return y + x * self.gamma 124 | 125 | 126 | class KBNet_s(nn.Module): 127 | def __init__(self, img_channel=3, width=64, middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8], 128 | dec_blk_nums=[2, 2, 2, 2], basicblock='KBBlock_s', lightweight=False, ffn_scale=2): 129 | super().__init__() 130 | basicblock = eval(basicblock) 131 | 132 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, 133 | groups=1, bias=True) 134 | 135 | self.encoders = nn.ModuleList() 136 | self.middle_blks = nn.ModuleList() 137 | self.decoders = nn.ModuleList() 138 | 139 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, 140 | groups=1, bias=True) 141 | 142 | self.ups = nn.ModuleList() 143 | self.downs = nn.ModuleList() 144 | 145 | chan = width 146 | for num in enc_blk_nums: 147 | self.encoders.append( 148 | nn.Sequential( 149 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)] 150 | ) 151 | ) 152 | self.downs.append( 153 | nn.Conv2d(chan, 2 * chan, 2, 2) 154 | ) 155 | chan = chan * 2 156 | 157 | self.middle_blks = \ 158 | nn.Sequential( 159 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(middle_blk_num)] 160 | ) 161 | 162 | for num in dec_blk_nums: 163 | self.ups.append( 164 | nn.Sequential( 165 | nn.Conv2d(chan, chan * 2, 1, bias=False), 166 | nn.PixelShuffle(2) 167 | ) 168 | ) 169 | chan = chan // 2 170 | self.decoders.append( 171 | nn.Sequential( 172 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)] 173 | ) 174 | ) 175 | 176 | self.padder_size = 2 ** len(self.encoders) 177 | 178 | def forward(self, inp): 179 | B, C, H, W = inp.shape 180 | inp = self.check_image_size(inp) 181 | x = self.intro(inp) 182 | 183 | encs = [] 184 | 185 | for encoder, down in zip(self.encoders, self.downs): 186 | x = encoder(x) 187 | encs.append(x) 188 | x = down(x) 189 | 190 | x = self.middle_blks(x) 191 | 192 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 193 | x = up(x) 194 | x = x + enc_skip 195 | x = decoder(x) 196 | 197 | x = self.ending(x) 198 | x = x + inp 199 | 200 | return x[:, :, :H, :W] 201 | 202 | def check_image_size(self, x): 203 | _, _, h, w = x.size() 204 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 205 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 206 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 207 | return x 208 | 209 | 210 | class SecretEncoderKBNet(nn.Module): 211 | def __init__(self, resolution=256, img_channel=3, secret_len=100, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 2, 4], 212 | dec_blk_nums=[1, 1, 1, 1], basicblock='KBBlock_s', lightweight=True, ffn_scale=1.5): 213 | super().__init__() 214 | basicblock = eval(basicblock) 215 | assert resolution % 16 == 0, 'resolution must be divisible by 16' 216 | self.secret_len = secret_len 217 | 218 | self.intro = nn.Conv2d(in_channels=img_channel*2, out_channels=width, kernel_size=3, padding=1, stride=1, 219 | groups=1, bias=True) 220 | self.secret_pre = 
nn.Linear(secret_len, 16*16*img_channel) 221 | self.secret_up = nn.Upsample(scale_factor=(resolution//16, resolution//16)) 222 | 223 | self.encoders = nn.ModuleList() 224 | self.middle_blks = nn.ModuleList() 225 | self.decoders = nn.ModuleList() 226 | 227 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, 228 | groups=1, bias=True) 229 | 230 | self.ups = nn.ModuleList() 231 | self.downs = nn.ModuleList() 232 | self.skip = nn.ModuleList() 233 | 234 | chan = width 235 | for num in enc_blk_nums: 236 | self.encoders.append( 237 | nn.Sequential( 238 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)] 239 | ) 240 | ) 241 | self.downs.append( 242 | nn.Conv2d(chan, 2 * chan, 2, 2) 243 | ) 244 | chan = chan * 2 245 | 246 | self.middle_blks = \ 247 | nn.Sequential( 248 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(middle_blk_num)] 249 | ) 250 | 251 | for num in dec_blk_nums: 252 | self.ups.append( 253 | nn.Sequential( 254 | # nn.Conv2d(chan, chan * 2, 1, bias=False), 255 | # nn.PixelShuffle(2) 256 | nn.Upsample(scale_factor=(2,2)), 257 | nn.ZeroPad2d((0, 1, 0, 1)), 258 | nn.Conv2d(chan, chan // 2, 2, 1), 259 | 260 | ) 261 | ) 262 | self.skip.append( 263 | nn.Sequential( 264 | nn.Conv2d(chan, chan//2, 3, 1, 1), 265 | nn.ReLU(inplace=True), 266 | ) 267 | ) 268 | chan = chan // 2 269 | self.decoders.append( 270 | nn.Sequential( 271 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)] 272 | ) 273 | ) 274 | 275 | self.padder_size = 2 ** len(self.encoders) 276 | 277 | def forward(self, image, secret): 278 | H, W = image.shape[-2:] 279 | image = self.check_image_size(image) 280 | secret = F.relu(self.secret_pre(secret)) 281 | secret = secret.view(-1, image.shape[1], 16, 16) 282 | secret = self.secret_up(secret) 283 | imgsec = torch.cat([image, secret], dim=1) # B, 6, 256, 256 284 | x = self.intro(imgsec) 285 | 286 | encs = [] 287 | 288 | for encoder, down in zip(self.encoders, self.downs): 289 | x = encoder(x) 290 | encs.append(x) 291 | x = down(x) 292 | 293 | x = self.middle_blks(x) 294 | # import pdb; pdb.set_trace() 295 | for decoder, up, enc_skip, skip in zip(self.decoders, self.ups, encs[::-1], self.skip): 296 | x = up(x) 297 | # x = x + enc_skip 298 | x = torch.cat([enc_skip, x], dim=1) 299 | x = skip(x) 300 | x = decoder(x) 301 | 302 | x = self.ending(x) # B, 3, 256, 256 303 | # x = x + image 304 | x = torch.tanh(x) 305 | 306 | return x[:, :, :H, :W] 307 | 308 | def check_image_size(self, x): 309 | _, _, h, w = x.size() 310 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 311 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 312 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 313 | return x 314 | 315 | -------------------------------------------------------------------------------- /python/trustmark/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | from .trustmark import TrustMark 8 | -------------------------------------------------------------------------------- /python/trustmark/bchecc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 
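# Added note: pure-Python BCH (Bose-Chaudhuri-Hocquenghem) codec used by the
# TrustMark data layer. It implements Galois-field arithmetic over GF(2^m),
# syndrome computation, error-locator polynomial construction, and a root
# search to locate and flip corrupted payload bits.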
3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 7 | 8 | 9 | from dataclasses import dataclass 10 | from copy import deepcopy 11 | 12 | 13 | class BCH(object): 14 | 15 | @dataclass 16 | class params: 17 | m: int 18 | t: int 19 | poly: int 20 | 21 | @dataclass 22 | class polynomial: 23 | deg: int 24 | 25 | 26 | 27 | ### GALOIS OPERATIONS 28 | 29 | def g_inv(self,a): 30 | return self.ECCstate.exponents[self.ECCstate.n - self.ECCstate.logarithms[a]] 31 | 32 | def g_sqrt(self, a): 33 | if a: 34 | return self.ECCstate.exponents[self.mod(2*self.ECCstate.logarithms[a])] 35 | else: 36 | return 0 37 | 38 | def mod(self, v): 39 | if v0 and b>0): 47 | res=self.mod(self.ECCstate.logarithms[a]+self.ECCstate.logarithms[b]) 48 | return (self.ECCstate.exponents[res]) 49 | else: 50 | return 0 51 | 52 | def g_div(self,a,b): 53 | if a: 54 | return self.ECCstate.exponents[self.mod(self.ECCstate.logarithms[a]+self.ECCstate.n-self.ECCstate.logarithms[b])] 55 | else: 56 | return 0 57 | 58 | def modn(self, v): 59 | n=self.ECCstate.n 60 | while (v>=n): 61 | v -= n 62 | v = (v & n) + (v >> self.ECCstate.m) 63 | return v 64 | 65 | def g_log(self, x): 66 | return self.ECCstate.logarithms[x] 67 | 68 | def a_ilog(self, x): 69 | return self.mod(self.ECCstate.n- self.ECCstate.logarithms[x]) 70 | 71 | def g_pow(self, i): 72 | return self.ECCstate.exponents[self.modn(i)] 73 | 74 | def deg(self, x): 75 | count=0 76 | while (x >> 1): 77 | x = x >> 1 78 | count += 1 79 | return count 80 | 81 | 82 | def ceilop(self, a, b): 83 | return int((a + b - 1) / b) 84 | 85 | def load4bytes(self, data): 86 | w=0 87 | w += data[0] << 24 88 | w += data[1] << 16 89 | w += data[2] << 8 90 | w += data[3] << 0 91 | return w 92 | 93 | 94 | def getroots(self, k, poly): 95 | 96 | roots=[] 97 | 98 | if poly.deg>2: 99 | k=k*8+self.ECCstate.ecc_bits 100 | 101 | rep=[0]*(self.ECCstate.t*2) 102 | d=poly.deg 103 | l=self.ECCstate.n-self.g_log(poly.c[poly.deg]) 104 | for i in range(0,d): 105 | if poly.c[i]: 106 | rep[i]=self.mod(self.g_log(poly.c[i])+l) 107 | else: 108 | rep[i]=-1 109 | 110 | rep[poly.deg]=0 111 | syn0=self.g_div(poly.c[0],poly.c[poly.deg]) 112 | for i in range(self.ECCstate.n-k+1, self.ECCstate.n+1): 113 | syn=syn0 114 | for j in range(1,poly.deg+1): 115 | m=rep[j] 116 | if m>=0: 117 | syn = syn ^ self.g_pow(m+j*i) 118 | if syn==0: 119 | roots.append(self.ECCstate.n-i) 120 | if len(roots)==poly.deg: 121 | break 122 | if len(roots)0): 171 | w=self.load4bytes(recvecc[offset:(offset+4)]) 172 | eccbuf.append(w) 173 | offset+=4 174 | mlen -=1 175 | recvecc=recvecc[offset:] 176 | leftdata=len(recvecc) 177 | if leftdata>0: #pad it to 4 178 | recvecc=recvecc+bytes([0]*(4-leftdata)) 179 | w=self.load4bytes(recvecc) 180 | eccbuf.append(w) 181 | 182 | eccwords=self.ceilop(self.ECCstate.m*self.ECCstate.t, 32) 183 | 184 | sum=0 185 | for i in range(0,eccwords): 186 | self.ECCstate.ecc_buf[i] = self.ECCstate.ecc_buf[i] ^ eccbuf[i] 187 | sum = sum | self.ECCstate.ecc_buf[i] 188 | if sum==0: 189 | return 0 # no bit flips 190 | 191 | 192 | s=self.ECCstate.ecc_bits 193 | t=self.ECCstate.t 194 | syn=[0]*(2*t) 195 | 196 | m= s & 31 197 | 198 | synbuf=self.ECCstate.ecc_buf 199 | 200 | if (m): 201 | synbuf[int(s/32)] = synbuf[int(s/32)] & ~(pow(2,32-m)-1) 202 | 203 | synptr=0 204 | while(s>0 or synptr==0): 205 | poly=synbuf[synptr] 206 | synptr += 1 207 | s-= 32 208 | while (poly): 209 | i=self.deg(poly) 210 | for j in range(0,(2*t),2): 211 | 
syn[j]=syn[j] ^ self.g_pow((j+1)*(i+s)) 212 | poly = poly ^ pow(2,i) 213 | 214 | 215 | for i in range(0,t): 216 | syn[2*i+1]=self.g_sqrt(syn[i]) 217 | 218 | 219 | n=self.ECCstate.n 220 | t=self.ECCstate.t 221 | pp=-1 222 | pd=1 223 | 224 | pelp=self.polynomial(deg=0) 225 | pelp.deg=0 226 | pelp.c= [0]*(2*t) 227 | pelp.c[0]=1 228 | 229 | elp=self.polynomial(deg=0) 230 | elp.c= [0]*(2*t) 231 | elp.c[0]=1 232 | 233 | d=syn[0] 234 | 235 | elp_copy=self.polynomial(deg=0) 236 | for i in range(0,t): 237 | if (elp.deg>t): 238 | break 239 | if d: 240 | k=2*i-pp 241 | elp_copy=deepcopy(elp) 242 | tmp=self.g_log(d)+n-self.g_log(pd) 243 | for j in range(0,(pelp.deg+1)): 244 | if (pelp.c[j]): 245 | l=self.g_log(pelp.c[j]) 246 | elp.c[j+k]=elp.c[j+k] ^ self.g_pow(tmp+l) 247 | 248 | 249 | tmp=pelp.deg+k 250 | if tmp>elp.deg: 251 | elp.deg=tmp 252 | pelp=deepcopy(elp_copy) 253 | pd=d 254 | pp=2*i 255 | if (i<t-1): 267 | if self.ECCstate.errloc[i] >= nbits: 268 | return -1 269 | self.ECCstate.errloc[i]=nbits-1-self.ECCstate.errloc[i] 270 | self.ECCstate.errloc[i]=(self.ECCstate.errloc[i] & ~7) | (7-(self.ECCstate.errloc[i] & 7)) 271 | 272 | 273 | for bitflip in self.ECCstate.errloc: 274 | byte= int (bitflip / 8) 275 | bit = pow(2,(bitflip & 7)) 276 | if bitflip < (len(data)+len(recvecc))*8: 277 | if byte<len(data): 303 | while (mlen>0): 304 | w=self.load4bytes(data[offset:(offset+4)]) 305 | w=w^r[0] 306 | p0=tab0idx+(l+1)*((w>>0) & 0xff) 307 | p1=tab1idx+(l+1)*((w>>8) & 0xff) 308 | p2=tab2idx+(l+1)*((w>>16) & 0xff) 309 | p3=tab3idx+(l+1)*((w>>24) & 0xff) 310 | 311 | for i in range(0,l): 312 | r[i]=r[i+1] ^ self.ECCstate.cyclic_tab[p0+i] ^ self.ECCstate.cyclic_tab[p1+i] ^ self.ECCstate.cyclic_tab[p2+i] ^ self.ECCstate.cyclic_tab[p3+i] 313 | 314 | r[l] = self.ECCstate.cyclic_tab[p0+l]^self.ECCstate.cyclic_tab[p1+l]^self.ECCstate.cyclic_tab[p2+l]^self.ECCstate.cyclic_tab[p3+l]; 315 | mlen -=1 316 | offset +=4 317 | 318 | 319 | data=data[offset:] 320 | leftdata=len(data) 321 | 322 | ecc=r 323 | posn=0 324 | while (leftdata): 325 | tmp=data[posn] 326 | posn += 1 327 | pidx = (l+1)*(((ecc[0] >> 24)^(tmp)) & 0xff) 328 | for i in range(0,l): 329 | ecc[i]=(((ecc[i] << 8)&0xffffffff)|ecc[i+1]>>24)^(self.ECCstate.cyclic_tab[pidx]) 330 | pidx += 1 331 | ecc[l]=((ecc[l] << 8)&0xffffffff)^(self.ECCstate.cyclic_tab[pidx]) 332 | leftdata -= 1 333 | 334 | self.ECCstate.ecc_buf=ecc 335 | eccout=[] 336 | for e in r: 337 | eccout.append((e >> 24) & 0xff) 338 | eccout.append((e >> 16) & 0xff) 339 | eccout.append((e >> 8) & 0xff) 340 | eccout.append((e >> 0) & 0xff) 341 | 342 | eccout=eccout[0:self.ECCstate.ecc_bytes] 343 | 344 | eccbytes=(bytearray(bytes(eccout))) 345 | return eccbytes 346 | 347 | 348 | 349 | def build_cyclic(self, g): 350 | 351 | l=self.ceilop(self.ECCstate.m*self.ECCstate.t, 32) 352 | 353 | plen=self.ceilop(self.ECCstate.ecc_bits+1,32) 354 | ecclen=self.ceilop(self.ECCstate.ecc_bits,32) 355 | 356 | self.ECCstate.cyclic_tab = [0] * 4*256*l 357 | 358 | for i in range(0,256): 359 | for b in range(0,4): 360 | offset= (b*256+i)*l 361 | data = i << 8*b 362 | while (data): 363 | 364 | d=self.deg(data) 365 | data = data ^ (g[0] >> (31-d)) 366 | for j in range(0,ecclen): 367 | if d<31: 368 | hi=(g[j] << (d+1)) & 0xffffffff 369 | else: 370 | hi=0 371 | if j+1 < plen: 372 | lo= g[j+1] >> (31-d) 373 | else: 374 | lo= 0 375 | self.ECCstate.cyclic_tab[j+offset] = self.ECCstate.cyclic_tab[j+offset] ^ (hi | lo) 376 | 377 | 378 | def __init__(self, t, poly): 379 | 380 | tmp = poly; 381 | m = 0; 382 | while (tmp >> 1): 383 | tmp =tmp >> 1 384 | m +=1 385 |
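# A worked example of the field setup (illustrative values, not quoted from a
# caller): TrustMark's 100-bit payload layout is consistent with a degree-7
# primitive polynomial such as poly=137 (0b10001001), for which the loop above
# yields m=7, i.e. the field GF(2^7) with n = 2^7 - 1 = 127. Each correctable
# bit flip then costs m=7 ECC bits: t=8 -> 56 bits (BCH_SUPER), t=5 -> 35,
# t=4 -> 28, t=3 -> 21.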
386 | self.ECCstate=self.params(m=m,t=t,poly=poly) 387 | 388 | self.ECCstate.n=pow(2,m)-1 389 | words = self.ceilop(m*t,32) 390 | self.ECCstate.ecc_bytes = self.ceilop(m*t,8) 391 | self.ECCstate.cyclic_tab=[0]*(words*1024) 392 | self.ECCstate.syn=[0]*(2*t) 393 | self.ECCstate.elp=[0]*(t+1) 394 | self.ECCstate.errloc=[0] * t 395 | 396 | 397 | x=1 398 | k=pow(2,self.deg(poly)) 399 | if k != pow(2,self.ECCstate.m): 400 | return -1 401 | 402 | self.ECCstate.exponents=[0]*(1+self.ECCstate.n) 403 | self.ECCstate.logarithms=[0]*(1+self.ECCstate.n) 404 | self.ECCstate.elp_pre=[0]*(1+self.ECCstate.m) 405 | 406 | for i in range(0,self.ECCstate.n): 407 | self.ECCstate.exponents[i]=x 408 | self.ECCstate.logarithms[x]=i 409 | if i and x==1: 410 | return -1 411 | x*= 2 412 | if (x & k): 413 | x=x^poly 414 | 415 | self.ECCstate.logarithms[0]=0 416 | self.ECCstate.exponents[self.ECCstate.n]=1 417 | 418 | 419 | 420 | n=0 421 | g=self.polynomial(deg=0) 422 | g.c=[0]*((m*t)+1) 423 | roots=[0]*(self.ECCstate.n+1) 424 | genpoly=[0]*self.ceilop(m*t+1,32) 425 | 426 | # enum all roots 427 | for i in range(0,t): 428 | r=2*i+1 429 | for j in range(0,m): 430 | roots[r]=1 431 | r=self.mod(2*r) 432 | 433 | # build g(x) 434 | g.deg=0 435 | g.c[0]=1 436 | for i in range(0,self.ECCstate.n): 437 | if roots[i]: 438 | r=self.ECCstate.exponents[i] 439 | g.c[g.deg+1]=1 440 | for j in range(g.deg,0,-1): 441 | g.c[j]=self.g_mul(g.c[j],r)^g.c[j-1] 442 | g.c[0]=self.g_mul(g.c[0],r) 443 | g.deg += 1 444 | 445 | # store 446 | n = g.deg+1 447 | i = 0 448 | 449 | while (n>0) : 450 | 451 | if n>32: 452 | nbits=32 453 | else: 454 | nbits=n 455 | 456 | word=0 457 | for j in range (0,nbits): 458 | if g.c[n-1-j] : 459 | word = word | pow(2,31-j) 460 | genpoly[i]=word 461 | i += 1 462 | n -= nbits 463 | self.ECCstate.ecc_bits=g.deg 464 | 465 | self.build_cyclic(genpoly); 466 | 467 | 468 | sum=0 469 | aexp=0 470 | for i in range(0,m): 471 | for j in range(0,m): 472 | sum = sum ^ self.g_pow(i*pow(2,j)) 473 | if sum: 474 | aexp=self.ECCstate.exponents[i] 475 | break 476 | 477 | x=0 478 | precomp=[0] * 31 479 | remaining=m 480 | 481 | while (x<= self.ECCstate.n and remaining): 482 | y=self.g_sqrt(x)^x 483 | for i in range(0,2): 484 | r=self.g_log(y) 485 | if (y and (r<m)): -------------------------------------------------------------------------------- /python/trustmark/denoise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 212 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous() 213 | n = torch.multinomial(torch.tensor([0.5,0.3,0.2]), 1).item() + 1 214 | stego = image 215 | for i in range(n): 216 | secret = torch.zeros_like(secret).random_(0, 2) 217 | stego = self.secret_embedder(stego, secret)[0] 218 | # stego = self.secret_embedder(image, secret)[0] 219 | out = [stego, image, secret] 220 | return out 221 | 222 | def forward(self, x): 223 | return torch.clamp(self.denoise(x), -1, 1) 224 | 225 | 226 | @torch.no_grad() 227 | def log_images(self, batch, fixed_input=False, **kwargs): 228 | log = dict() 229 | x, y, s = self.get_input(batch) 230 | x_denoised = self(x) 231 | log['clean'] = y 232 | log['stego'] = x 233 | log['denoised'] = x_denoised 234 | residual = x_denoised - y 235 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1 236 | return log 237 | 238 | 239 | 240 | class SimpleUnet(nn.Module): 241 | def __init__(self, dim=32) -> None: 242 | super().__init__() 243 | self.conv1 = nn.Conv2d(3, dim, 3, 1, 1) 244 | self.conv2 = nn.Conv2d(dim, dim, 3, 2, 1) 245 | self.conv3 = nn.Conv2d(dim, dim*2, 3, 2, 1) 246 | self.conv4 = nn.Conv2d(dim*2, dim*4, 3, 2, 1) 247 | self.conv5 = nn.Conv2d(dim*4, dim*8, 3, 2, 1) 248 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1)) 249 | self.up6 = nn.Conv2d(dim*8, dim*4, 2,
1) 250 | self.upsample6 = nn.Upsample(scale_factor=(2, 2)) 251 | self.conv6 = nn.Conv2d(dim*4 + dim*4, dim*4, 3, 1, 1) 252 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1)) 253 | self.up7 = nn.Conv2d(dim*4, dim*2, 2, 1) 254 | self.upsample7 = nn.Upsample(scale_factor=(2, 2)) 255 | self.conv7 = nn.Conv2d(dim*2 + dim*2, dim*2, 3, 1, 1) 256 | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1)) 257 | self.up8 = nn.Conv2d(dim*2, dim, 2, 1) 258 | self.upsample8 = nn.Upsample(scale_factor=(2, 2)) 259 | self.conv8 = nn.Conv2d(dim+dim, dim, 3, 1, 1) 260 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1)) 261 | self.up9 = nn.Conv2d(dim, dim, 2, 1) 262 | self.upsample9 = nn.Upsample(scale_factor=(2, 2)) 263 | self.conv9 = nn.Conv2d(dim + dim + 3, dim, 3, 1, 1) 264 | self.conv10 = nn.Conv2d(dim, dim, 3, 1, 1) 265 | self.post = nn.Conv2d(dim, dim//2, 1) 266 | self.silu = nn.SiLU() 267 | self.out = nn.Conv2d(dim//2, 3, 1) 268 | 269 | def forward(self, image): 270 | inputs = image 271 | 272 | conv1 = thf.relu(self.conv1(inputs)) 273 | conv2 = thf.relu(self.conv2(conv1)) 274 | conv3 = thf.relu(self.conv3(conv2)) 275 | conv4 = thf.relu(self.conv4(conv3)) 276 | conv5 = thf.relu(self.conv5(conv4)) 277 | up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5)))) 278 | merge6 = torch.cat([conv4, up6], dim=1) 279 | conv6 = thf.relu(self.conv6(merge6)) 280 | up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6)))) 281 | merge7 = torch.cat([conv3, up7], dim=1) 282 | conv7 = thf.relu(self.conv7(merge7)) 283 | up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7)))) 284 | merge8 = torch.cat([conv2, up8], dim=1) 285 | conv8 = thf.relu(self.conv8(merge8)) 286 | up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8)))) 287 | merge9 = torch.cat([conv1, up9, inputs], dim=1) 288 | conv9 = thf.relu(self.conv9(merge9)) 289 | conv10 = thf.relu(self.conv10(conv9)) 290 | post = self.silu(self.post(conv10)) 291 | out = thf.tanh(self.out(post)) 292 | return out 293 | 294 | 295 | 296 | 297 | def get_obj_from_str(string, reload=False): 298 | module, cls = string.rsplit(".", 1) 299 | if reload: 300 | module_imp = importlib.import_module(module) 301 | importlib.reload(module_imp) 302 | return getattr(importlib.import_module(module, package=None), cls) 303 | 304 | 305 | 306 | def instantiate_from_config(config): 307 | if not "target" in config: 308 | if config == '__is_first_stage__': 309 | return None 310 | elif config == "__is_unconditional__": 311 | return None 312 | raise KeyError("Expected key `target` to instantiate.") 313 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 314 | 315 | 316 | -------------------------------------------------------------------------------- /python/trustmark/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Adobe 2 | # All Rights Reserved. 3 | 4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | # accordance with the terms of the Adobe license agreement accompanying 6 | # it. 
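# A minimal sketch of the config-driven instantiation used throughout this
# module (the config dict here is hypothetical): `instantiate_from_config`,
# defined at the end of this file, resolves the class named by `target` and
# calls it with `params` as keyword arguments:
#
#   cfg = {'target': 'trustmark.model.Identity', 'params': {}}
#   obj = instantiate_from_config(cfg)  # equivalent to Identity()
#
# TrustMark_Arch receives its encoder, decoder, loss, and discriminator as
# configs of this shape.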
7 | 8 | import math 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as thf 12 | try: 13 | import lightning as pl 14 | except ImportError: 15 | import pytorch_lightning as pl 16 | import einops 17 | import numpy as np 18 | import torchvision 19 | import importlib 20 | from torchmetrics.functional import peak_signal_noise_ratio 21 | from contextlib import contextmanager 22 | 23 | 24 | class Identity(nn.Module): 25 | def __init__(self,*args,**kwargs): 26 | super().__init__() 27 | def forward(self, x): 28 | return x 29 | 30 | 31 | class TrustMark_Arch(pl.LightningModule): 32 | def __init__(self, 33 | cover_key, 34 | secret_key, 35 | secret_len, 36 | resolution, 37 | secret_encoder_config, 38 | secret_decoder_config, 39 | discriminator_config, 40 | loss_config, 41 | bit_acc_thresholds=[0.9, 0.95, 0.98], 42 | noise_config='__none__', 43 | ckpt_path="__none__", 44 | lr_scheduler='__none__', 45 | use_ema=False 46 | ): 47 | super().__init__() 48 | self.automatic_optimization = False 49 | self.cover_key = cover_key 50 | self.secret_key = secret_key 51 | secret_encoder_config.params.secret_len = secret_len 52 | secret_decoder_config.params.secret_len = secret_len 53 | secret_encoder_config.params.resolution = resolution 54 | secret_decoder_config.params.resolution = 224 55 | self.encoder = instantiate_from_config(secret_encoder_config) 56 | self.decoder = instantiate_from_config(secret_decoder_config) 57 | self.loss_layer = instantiate_from_config(loss_config) 58 | self.discriminator = instantiate_from_config(discriminator_config) 59 | 60 | if noise_config != '__none__': 61 | self.noise = instantiate_from_config(noise_config) 62 | 63 | self.lr_scheduler = None if lr_scheduler == '__none__' else lr_scheduler 64 | 65 | self.use_ema = use_ema 66 | if self.use_ema: 67 | print('Using EMA') 68 | self.encoder_ema = LitEma(self.encoder) 69 | self.decoder_ema = LitEma(self.decoder) 70 | self.discriminator_ema = LitEma(self.discriminator) 71 | print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()) + list(self.discriminator_ema.buffers()))}.") 72 | 73 | if ckpt_path != "__none__": 74 | self.init_from_ckpt(ckpt_path, ignore_keys=[]) 75 | 76 | # early training phase 77 | self.fixed_img = None 78 | self.fixed_secret = None 79 | self.register_buffer("fixed_input", torch.tensor(True)) 80 | self.register_buffer("update_gen", torch.tensor(False)) # update generator to fool discriminator 81 | self.bit_acc_thresholds = bit_acc_thresholds 82 | self.crop = Identity() 83 | 84 | def init_from_ckpt(self, path, ignore_keys=list()): 85 | sd = torch.load(path, map_location="cpu")["state_dict"] 86 | keys = list(sd.keys()) 87 | for k in keys: 88 | for ik in ignore_keys: 89 | if k.startswith(ik): 90 | print("Deleting key {} from state_dict.".format(k)) 91 | del sd[k] 92 | self.load_state_dict(sd, strict=False) 93 | print(f"Restored from {path}") 94 | 95 | 96 | 97 | @torch.no_grad() 98 | def get_input(self, batch, bs=None): 99 | image = batch[self.cover_key] 100 | secret = batch[self.secret_key] 101 | if bs is not None: 102 | image = image[:bs] 103 | secret = secret[:bs] 104 | else: 105 | bs = image.shape[0] 106 | # encode image 1st stage 107 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous() 108 | 109 | # check if using fixed input (early training phase) 110 | # if self.training and self.fixed_input: 111 | if self.fixed_input: 112 | if self.fixed_img is None: # first iteration 113 | print('[TRAINING] Warmup - using fixed input image for 
now!') 114 | self.fixed_img = image.detach().clone()[:bs] 115 | self.fixed_secret = secret.detach().clone()[:bs] # use for log_images with fixed_input option only 116 | image = self.fixed_img 117 | new_bs = min(secret.shape[0], image.shape[0]) 118 | image, secret = image[:new_bs], secret[:new_bs] 119 | 120 | out = [image, secret] 121 | return out 122 | 123 | def forward(self, cover, secret): 124 | # return a tuple (stego, residual) 125 | enc_out = self.encoder(cover, secret) 126 | if hasattr(self.encoder, 'return_residual') and self.encoder.return_residual: 127 | return cover + enc_out, enc_out 128 | else: 129 | return enc_out, enc_out - cover 130 | 131 | 132 | 133 | @torch.no_grad() 134 | def log_images(self, batch, fixed_input=False, **kwargs): 135 | log = dict() 136 | if fixed_input and self.fixed_img is not None: 137 | x, s = self.fixed_img, self.fixed_secret 138 | else: 139 | x, s = self.get_input(batch) 140 | stego, residual = self(x, s) 141 | if hasattr(self, 'noise') and self.noise.is_activated(): 142 | img_noise = self.noise(stego, self.global_step, p=1.0) 143 | log['noised'] = img_noise 144 | log['input'] = x 145 | log['stego'] = stego 146 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1 147 | return log 148 | 149 | 150 | def get_obj_from_str(string, reload=False): 151 | module, cls = string.rsplit(".", 1) 152 | if reload: 153 | module_imp = importlib.import_module(module) 154 | importlib.reload(module_imp) 155 | return getattr(importlib.import_module(module, package=None), cls) 156 | 157 | 158 | 159 | def instantiate_from_config(config): 160 | if not "target" in config: 161 | if config == '__is_first_stage__': 162 | return None 163 | elif config == "__is_unconditional__": 164 | return None 165 | raise KeyError("Expected key `target` to instantiate.") 166 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 167 | 168 | 169 | -------------------------------------------------------------------------------- /python/trustmark/models/README.md: -------------------------------------------------------------------------------- 1 | Models will be fetched to this folder on first use 2 | -------------------------------------------------------------------------------- /rust/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [alias] 2 | xtask = "run --package xtask --" 3 | -------------------------------------------------------------------------------- /rust/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | /models/* 4 | !/models/.gitkeep 5 | 6 | /env/ 7 | 8 | *.py 9 | 10 | .envrc 11 | 12 | .DS_Store 13 | 14 | *.png 15 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "trustmark" 3 | description = "A Rust implementation of TrustMark" 4 | version = "0.2.0" 5 | authors = [ 6 | "John Collomosse ", 7 | "Maurice Fisher ", 8 | "Andrew Halle ", 9 | ] 10 | rust-version = "1.75.0" 11 | edition = "2021" 12 | license = "MIT" 13 | keywords = ["trustmark", "cv", "watermark"] 14 | categories = ["multimedia::images"] 15 | repository = "https://github.com/adobe/trustmark" 16 | 17 | [workspace] 18 | members = ["crates/*"] 19 | 20 | [[bench]] 21 | name = "load" 22 | harness = false 23 | 24 | [[bench]] 25 | name = "encode" 26 | harness = false 27 | 28 | [dependencies] 29 | 
image = "0.25.6" 30 | fast_image_resize = { version = "5.1.4", features = ["image", "rayon"] } 31 | ndarray = "0.16" 32 | ort = "=2.0.0-rc.8" 33 | thiserror = "1" 34 | 35 | [dev-dependencies] 36 | criterion = "0.5" 37 | -------------------------------------------------------------------------------- /rust/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Adobe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rust/README.md: -------------------------------------------------------------------------------- 1 | # TrustMark — Rust implementation 2 | 3 |
4 | 5 | An implementation in Rust of TrustMark watermarking, as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2]. 6 | 7 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK. 8 | 9 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA. 10 | 11 |
 12 | 13 | This crate implements a subset of the functionality of the TrustMark Python implementation, including encoding and decoding of watermarks for all variants in binary mode. The Rust implementation provides the same levels of error correction as the Python implementation. 14 | 15 | Text mode watermarks and watermark removal are not implemented. 16 | 17 | Open an issue if there's something in the Python version that you'd like added to this crate! 18 | 19 | ## Quick start 20 | 21 | ### Download models 22 | 23 | In order to encode or decode watermarks, you'll need to fetch the model files. The models are distributed as ONNX files. 24 | 25 | From the workspace root (the `rust/` directory), run: 26 | 27 | ``` 28 | cargo xtask fetch-models 29 | ``` 30 | 31 | This command downloads models to the `models/` directory. You can move them from there as needed. 32 | 33 | ### Run the CLI 34 | 35 | As a first step, you can run the `trustmark-cli` binary that is defined in this repository. 36 | 37 | From the workspace root, run: 38 | 39 | ```sh 40 | cargo run --release -p trustmark-cli -- -m ./models encode -i ../images/ghost.png -o ../images/encoded.png 41 | cargo run --release -p trustmark-cli -- -m ./models decode -i ../images/encoded.png 42 | ``` 43 | 44 | The argument to the `-m` option is the path to the downloaded models; if you moved them, pass their new location as the option value. 45 | 46 | ### Use the library 47 | 48 | Add `trustmark` to your project's Cargo manifest with: 49 | 50 | ``` 51 | cargo add trustmark 52 | ``` 53 | 54 | A basic example of using `trustmark` is: 55 | 56 | ```rust 57 | use trustmark::{Trustmark, Version, Variant}; 58 | 59 | let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap(); 60 | let input = image::open("../images/ghost.png").unwrap(); 61 | let output = tm.encode("0010101".to_owned(), input, 0.95); 62 | ``` 63 | 64 | ## Running the benchmarks 65 | 66 | ### Rust benchmarks 67 | 68 | To run the Rust benchmarks, run the following from the workspace root: 69 | 70 | ``` 71 | cargo bench 72 | ``` 73 | 74 | ### Python benchmarks 75 | 76 | To run the Python benchmarks, run the following from the workspace root: 77 | 78 | ``` 79 | benches/load.sh && benches/encode.sh 80 | ``` 81 | -------------------------------------------------------------------------------- /rust/benches/encode.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Adobe 2 | // All Rights Reserved. 3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it. 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion}; 9 | use trustmark::{Trustmark, Variant, Version}; 10 | 11 | fn encode_main(c: &mut Criterion) { 12 | let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap(); 13 | let input = image::open("../images/ufo_240.jpg").unwrap(); 14 | c.bench_function("encode", |b| { 15 | b.iter(|| { 16 | let _ = tm.encode( 17 | "0100100100100001000101001010".to_owned(), 18 | input.clone(), 19 | 0.95, 20 | ); 21 | }) 22 | }); 23 | } 24 | 25 | criterion_group!(benches, encode_main); 26 | criterion_main!(benches); 27 | -------------------------------------------------------------------------------- /rust/benches/encode.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | # Copyright 2025 Adobe 4 | # All Rights Reserved.
5 | # 6 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 7 | # accordance with the terms of the Adobe license agreement accompanying 8 | # it. 9 | 10 | set -euxo pipefail 11 | 12 | python -m timeit \ 13 | -s "from PIL import Image; from trustmark import TrustMark; tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5); image = Image.open('../images/ufo_240.jpg').convert('RGB')" \ 14 | "tm.encode(image, '0100100100100001000101001010', MODE='BINARY')" 15 | -------------------------------------------------------------------------------- /rust/benches/load.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Adobe 2 | // All Rights Reserved. 3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it. 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion}; 9 | use trustmark::{Trustmark, Variant, Version}; 10 | 11 | fn load_main(c: &mut Criterion) { 12 | c.bench_function("load", |b| { 13 | b.iter(|| { 14 | let _tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap(); 15 | }) 16 | }); 17 | } 18 | 19 | criterion_group!(benches, load_main); 20 | criterion_main!(benches); 21 | -------------------------------------------------------------------------------- /rust/benches/load.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | # Copyright 2025 Adobe 4 | # All Rights Reserved. 5 | # 6 | # NOTICE: Adobe permits you to use, modify, and distribute this file in 7 | # accordance with the terms of the Adobe license agreement accompanying 8 | # it. 9 | 10 | set -euxo pipefail 11 | 12 | python -m timeit -s 'from trustmark import TrustMark' "tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5)" 13 | -------------------------------------------------------------------------------- /rust/crates/trustmark-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "trustmark-cli" 3 | description = "A CLI for trustmark encoding/decoding" 4 | version = "0.1.0" 5 | authors = [ 6 | "John Collomosse ", 7 | "Maurice Fisher ", 8 | "Andrew Halle ", 9 | ] 10 | license = "MIT" 11 | keywords = ["trustmark", "cv", "watermark", "cli"] 12 | edition = "2021" 13 | publish = false 14 | 15 | [[bin]] 16 | name = "trustmark" 17 | path = "src/main.rs" 18 | 19 | [dependencies] 20 | clap = { version = "4.5.20", features = ["derive"] } 21 | image = "0.25.6" 22 | rand = "0.8.5" 23 | trustmark = { path = "../.." } 24 | -------------------------------------------------------------------------------- /rust/crates/trustmark-cli/README.md: -------------------------------------------------------------------------------- 1 | # TrustMark CLI 2 | 3 | The Rust implementation includes a CLI wrapper for the `trustmark` crate. 4 | 5 | ## Installation 6 | 7 | To install the CLI, run this command from the `trustmark/crates/trustmark-cli` directory: 8 | 9 | ``` 10 | cargo install --locked --path . 11 | ``` 12 | 13 | ### Downloading models 14 | 15 | To use the CLI, you must first create a `models` directory and download the models, [if you haven't already done so](../../README.md#download-models). 
Enter these commands from the `trustmark/rust` directory: 16 | 17 | ``` 18 | mkdir models 19 | cargo xtask fetch-models 20 | ``` 21 | 22 | ## Usage 23 | 24 | View CLI help information by entering this command: 25 | 26 | ``` 27 | trustmark [encode | decode] help 28 | ``` 29 | 30 | The basic command syntax is: 31 | 32 | ``` 33 | trustmark --models <models> [encode | decode] 34 | ``` 35 | 36 | Where `<models>` is the relative path to the directory containing models. 37 | Use the `encode` subcommand to encode a watermark into an image and the `decode` subcommand to decode a watermark from an image. 38 | 39 | ### Encoding watermarks 40 | 41 | To encode a watermark into an image, use the `encode` subcommand: 42 | 43 | ``` 44 | trustmark --models <models> encode [OPTIONS] -i <input> -o <output> 45 | ``` 46 | 47 | Options: 48 | 49 | | Option | Description | Allowed Values | 50 | |--------|--------------|----------------| 51 | | `-i <input>` | Path to the image to encode. | Relative file path. | 52 | | `-o <output>` | Path to file in which to save the watermarked image. | Relative file path. | 53 | | `-w, --watermark <watermark>` | The watermark (payload) to encode. | A binary string such as `0101010101`. Only 0 and 1 characters are allowed. Maximum length is governed by the version selected. Default is a random binary string. | 54 | | `--version <version>` | The BCH version to encode with. | One of `BCH_SUPER`, `BCH_5` (default), `BCH_4`, or `BCH_3`. | 55 | | `--variant <variant>` | The model variant to encode with. | `Q` (default), `B`, `C`, and `P`. | 56 | | `--quality <quality>` | If the requested output format is JPEG, the quality to use when encoding. | A number between 0 and 100. The default is 90. | 57 | | `-h, --help` | Display help information. | N/A | 58 | 59 | ### Decoding watermarks 60 | 61 | To decode a watermark from an image, use the `decode` subcommand: 62 | 63 | ``` 64 | trustmark --models <models> decode [OPTIONS] -i <input> 65 | ``` 66 | 67 | | Option | Description | Allowed Values | 68 | |--------|--------------|----------------| 69 | | `-i <input>` | Path to the image to decode. | Relative file path. | 70 | | `--variant <variant>` | The model variant to decode with. Must match variant used to encode the watermark. | `Q` (default), `B`, `C`, and `P`. | 71 | | `-h, --help` | Display help information. | N/A | 72 | 73 | ## Examples 74 | 75 | To encode a watermark into one of the sample images, run this command from the workspace root: 76 | 77 | ```sh 78 | trustmark -m ./models encode -i ../images/ghost.png -o ../images/ghost_encoded.png 79 | ``` 80 | 81 | Then to decode the watermark from this image, run this command from the workspace root: 82 | 83 | ```sh 84 | trustmark -m ./models decode -i ../images/ghost_encoded.png 85 | ``` 86 | 87 | You'll see something like this in your terminal: 88 | 89 | ``` 90 | Found watermark: 0101111001101001000000011011010100100010011101101101000001101 91 | ``` 92 | -------------------------------------------------------------------------------- /rust/crates/trustmark-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Adobe 2 | // All Rights Reserved. 3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it.
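// Overview of the flow implemented below: `main` parses `Args`, loads the
// ONNX encoder/decoder pair for the requested model variant once via
// `Trustmark::new`, and then dispatches on the subcommand; `encode` embeds a
// (possibly randomly generated) watermark and writes the output image, while
// `decode` prints the recovered bitstring or reports a corrupt/missing
// watermark.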
7 | 8 | use std::{fs::OpenOptions, path::PathBuf}; 9 | 10 | use clap::{Parser, Subcommand}; 11 | use image::{codecs::jpeg::JpegEncoder, ImageFormat}; 12 | use rand::{distributions::Standard, prelude::Distribution as _}; 13 | use trustmark::{Trustmark, Variant, Version}; 14 | 15 | #[derive(Debug, Parser)] 16 | struct Args { 17 | #[arg(short, long)] 18 | models: PathBuf, 19 | #[command(subcommand)] 20 | command: Command, 21 | } 22 | 23 | #[derive(Debug, Subcommand)] 24 | enum Command { 25 | /// Encode a watermark into an image 26 | Encode { 27 | /// The image to encode. 28 | #[arg(short)] 29 | input: PathBuf, 30 | /// The path to save the watermarked image. 31 | #[arg(short)] 32 | output: PathBuf, 33 | /// The watermark to encode. Defaults to random if not specified. 34 | #[arg(short, long)] 35 | watermark: Option<String>, 36 | /// The BCH version to encode with. Defaults to `Bch5`. 37 | #[arg(long)] 38 | version: Option<Version>, 39 | /// The model variant to encode with. 40 | #[arg(long)] 41 | variant: Option<Variant>, 42 | /// If the requested output is JPEG, the quality to use for encoding. 43 | #[arg(long)] 44 | quality: Option<u8>, 45 | }, 46 | /// Decode a watermark from an image 47 | Decode { 48 | #[arg(short)] 49 | input: PathBuf, 50 | /// The model variant to decode with. 51 | #[arg(long)] 52 | variant: Option<Variant>, 53 | }, 54 | } 55 | 56 | impl Command { 57 | /// Extract the version to use from this `Command`. 58 | fn get_version(&self) -> Version { 59 | match self { 60 | Command::Encode { 61 | version: Some(version), 62 | .. 63 | } => *version, 64 | _ => Version::Bch5, 65 | } 66 | } 67 | 68 | /// Extract the variant to use from this `Command`. 69 | fn get_variant(&self) -> Variant { 70 | match self { 71 | Command::Encode { 72 | variant: Some(variant), 73 | .. 74 | } => *variant, 75 | Command::Decode { 76 | variant: Some(variant), 77 | .. 78 | } => *variant, 79 | _ => Variant::Q, 80 | } 81 | } 82 | } 83 | 84 | /// Generate a random watermark with as many bits as specified by `bits`. 85 | /// 86 | /// # Example 87 | /// 88 | /// ```rust 89 | /// # fn main() { 90 | /// println!("{}", gen_watermark(4)); // 1010 91 | /// # } 92 | /// ``` 93 | fn gen_watermark(bits: usize) -> String { 94 | let mut rng = rand::thread_rng(); 95 | let v: Vec<bool> = Standard.sample_iter(&mut rng).take(bits).collect(); 96 | v.into_iter() 97 | .map(|bit| if bit { '1' } else { '0' }) 98 | .collect() 99 | } 100 | 101 | fn main() { 102 | let args = Args::parse(); 103 | let tm = Trustmark::new( 104 | &args.models, 105 | args.command.get_variant(), 106 | args.command.get_version(), 107 | ) 108 | .unwrap(); 109 | match args.command { 110 | Command::Encode { 111 | input, 112 | output, 113 | watermark, 114 | version, 115 | quality, 116 | .. 117 | } => { 118 | let input = image::open(input).unwrap(); 119 | let watermark = watermark.unwrap_or_else(|| { 120 | gen_watermark(version.unwrap_or(Version::Bch5).data_bits().into()) 121 | }); 122 | let encoded = tm.encode(watermark.clone(), input, 0.95).unwrap(); 123 | 124 | let format = ImageFormat::from_path(&output).unwrap(); 125 | match format { 126 | // JPEG encoding can make visual artifacts worse, so we encode with a higher 127 | // quality than the default (or the quality requested by the user).
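// (For reference: the `image` crate's plain `JpegEncoder::new` defaults to
// quality 75, which can be low enough to degrade the embedded residual; the
// fallback of 90 below trades file size for watermark robustness.)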
128 | ImageFormat::Jpeg => { 129 | let quality = quality.unwrap_or(90); 130 | let mut writer = OpenOptions::new() 131 | .write(true) 132 | .create(true) 133 | .truncate(true) 134 | .open(&output) 135 | .unwrap(); 136 | let encoder = JpegEncoder::new_with_quality(&mut writer, quality); 137 | encoded.to_rgb8().write_with_encoder(encoder).unwrap(); 138 | } 139 | _ => { 140 | encoded.to_rgba8().save(&output).unwrap(); 141 | } 142 | } 143 | } 144 | Command::Decode { input, .. } => { 145 | let input = image::open(input).unwrap(); 146 | match tm.decode(input) { 147 | Ok(decoded) => println!("Found watermark: {decoded}"), 148 | Err(trustmark::Error::CorruptWatermark) => { 149 | println!("Corrupt or missing watermark") 150 | } 151 | err => panic!("{err:?}"), 152 | } 153 | } 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /rust/crates/xtask/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xtask" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | clap = { version = "4.5.20", features = ["derive"] } 9 | ureq = "2.10.1" 10 | -------------------------------------------------------------------------------- /rust/crates/xtask/src/main.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Adobe 2 | // All Rights Reserved. 3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it. 7 | 8 | use std::{ 9 | fs::OpenOptions, 10 | io::{self, Read as _}, 11 | }; 12 | 13 | use clap::Parser; 14 | 15 | #[derive(Debug, Parser)] 16 | enum Args { 17 | FetchModels, 18 | } 19 | 20 | fn main() { 21 | let args = Args::parse(); 22 | match args { 23 | Args::FetchModels => fetch_models(), 24 | } 25 | } 26 | 27 | /// Fetch all known models. 28 | fn fetch_models() { 29 | fetch_model("decoder_Q.onnx"); 30 | fetch_model("encoder_Q.onnx"); 31 | fetch_model("decoder_P.onnx"); 32 | fetch_model("encoder_P.onnx"); 33 | fetch_model("decoder_B.onnx"); 34 | fetch_model("encoder_B.onnx"); 35 | fetch_model("decoder_C.onnx"); 36 | fetch_model("encoder_C.onnx"); 37 | } 38 | 39 | /// Fetch a single model identified by `filename`. 40 | /// 41 | /// Models are fetched from a hardcoded CDN URL. 42 | fn fetch_model(filename: &str) { 43 | let root = "https://cc-assets.netlify.app/watermarking/trustmark-models"; 44 | let model_url = format!("{root}/{filename}",); 45 | let mut decoder = ureq::get(&model_url) 46 | .call() 47 | .unwrap() 48 | .into_reader() 49 | .take(100_000_000); 50 | let mut file = OpenOptions::new() 51 | .write(true) 52 | .create_new(true) 53 | .open(format!("models/{filename}")) 54 | .unwrap(); 55 | io::copy(&mut decoder, &mut file).unwrap(); 56 | } 57 | -------------------------------------------------------------------------------- /rust/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/rust/models/.gitkeep -------------------------------------------------------------------------------- /rust/src/bits.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Adobe 2 | // All Rights Reserved. 
3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it. 7 | 8 | use std::{fmt::Display, str::FromStr}; 9 | 10 | use ndarray::{Array1, ArrayD, Axis}; 11 | use ort::{TensorValueType, Value}; 12 | 13 | const VERSION_BITS: u16 = 4; 14 | 15 | mod bch; 16 | 17 | #[derive(Debug)] 18 | pub(super) struct Bits(String); 19 | 20 | /// Error type for the `bits` module. 21 | #[derive(Debug, thiserror::Error)] 22 | pub enum Error { 23 | /// Something went wrong while doing inference. 24 | #[error("onnx error: {0}")] 25 | Ort(#[from] ort::Error), 26 | 27 | /// A character was encountered that was not a '0' or a '1'. Strings that specify bitstrings 28 | /// should only use these two characters. 29 | #[error("allowed chars are '0' and '1'")] 30 | InvalidChar, 31 | 32 | /// A bitstring has too many bits to fit in the requested version's data payload. 33 | #[error( 34 | "input bitstring ({bits} bits) has more bits than version allows ({version_allows} bits)" 35 | )] 36 | InvalidDataLength { version_allows: usize, bits: usize }, 37 | 38 | /// A bitstring of a certain length is required, and a different length was encountered. 39 | #[error("must be of length 100")] 40 | InvalidLength, 41 | 42 | /// Model input/output did not have the expected dimensions. 43 | #[error("invalid dimensions")] 44 | InvalidDim, 45 | 46 | /// String does not represent a known version. 47 | #[error("invalid version")] 48 | InvalidVersion, 49 | 50 | /// Watermark is missing or corrupted. 51 | /// 52 | /// Either the image did not have a valid watermark, or too many transmission errors/image 53 | /// artifacts occurred such that the watermark is no longer recoverable. 54 | #[error("corrupt watermark")] 55 | CorruptWatermark, 56 | } 57 | 58 | impl Bits { 59 | /// Constructs a `Bits`, adding in the additional error correction and schema bits. 60 | pub(super) fn apply_error_correction_and_schema( 61 | mut input: String, 62 | version: Version, 63 | ) -> Result<Self, Error> { 64 | let data_bits: usize = version.data_bits().into(); 65 | 66 | if input.chars().any(|c| c != '0' && c != '1') { 67 | return Err(Error::InvalidChar); 68 | } 69 | 70 | if input.len() > data_bits { 71 | return Err(Error::InvalidDataLength { 72 | bits: input.len(), 73 | version_allows: data_bits, 74 | }); 75 | } 76 | 77 | // pad the input 78 | input.push_str(&"0".repeat(data_bits - input.len() + (8 - data_bits % 8))); 79 | 80 | // pack the input into bytes 81 | let data: Vec<u8> = input 82 | .as_bytes() 83 | .chunks(8) 84 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap()) 85 | .collect(); 86 | 87 | // calculate the error correction bits 88 | let mut ecc_state = bch::bch_init(version.allowed_bit_flips() as u32, bch::POLYNOMIAL); 89 | let ecc = bch::bch_encode(&mut ecc_state, &data); 90 | 91 | // form a bitstring from the error correction bits 92 | let mut error_correction: String = ecc 93 | .iter() 94 | .map(|byte| format!("{byte:08b}")) 95 | .collect::<Vec<String>>() 96 | .join(""); 97 | 98 | // split off unneeded padding 99 | input.truncate(data_bits); 100 | error_correction.truncate(version.ecc_bits().into()); 101 | 102 | // form the encoded string 103 | Ok(Self(format!( 104 | "{input}{error_correction}{}", 105 | version.bitstring() 106 | ))) 107 | } 108 | 109 | /// Get the data out of a `Bits` by removing the error correction bits.
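///
/// For example, with `Version::Bch5` this returns the first 61 bits of the
/// 100-bit payload, dropping the 35 ECC bits and the 4 version bits that
/// follow them.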
110 | pub(super) fn get_data(self) -> String { 111 | let version = self.get_version(); 112 | let Self(mut s) = self; 113 | s.truncate(version.data_bits().into()); 114 | s 115 | } 116 | 117 | /// Get the version from the bits. 118 | pub(super) fn get_version(&self) -> Version { 119 | match &self.0[98..100] { 120 | "00" => Version::BchSuper, 121 | "01" => Version::Bch5, 122 | "10" => Version::Bch4, 123 | "11" => Version::Bch3, 124 | _ => unreachable!(), 125 | } 126 | } 127 | 128 | /// Construct a `Bits` from a bitstring. 129 | /// 130 | /// This function checks for bitflips in the bitstring using the error-correcting bits, and 131 | /// corrects them if there are fewer bitflips than are supported by the version. As a last 132 | /// resort, this function checks for bitflips in the version identifier by trying all possible 133 | /// versions. 134 | fn new(s: String) -> Result<Self, Error> { 135 | if s.chars().any(|c| c != '0' && c != '1') { 136 | return Err(Error::InvalidChar); 137 | } 138 | 139 | if s.len() != 100 { 140 | return Err(Error::InvalidLength); 141 | } 142 | 143 | let version: Version = Version::from_bitstring(&s[96..]).unwrap_or_default(); 144 | 145 | if let Ok(bits) = Bits::new_with_version(&s, version) { 146 | Ok(bits) 147 | } else { 148 | let mut versions = vec![ 149 | Version::Bch3, 150 | Version::Bch4, 151 | Version::Bch5, 152 | Version::BchSuper, 153 | ]; 154 | versions.retain(|v| *v != version); 155 | let mut res = None; 156 | for version in versions { 157 | res = Some(Bits::new_with_version(&s, version)); 158 | if res.as_ref().unwrap().is_ok() { 159 | return res.unwrap(); 160 | } 161 | } 162 | res.unwrap() 163 | } 164 | } 165 | 166 | fn new_with_version(s: &str, version: Version) -> Result<Self, Error> { 167 | let data_bits: usize = version.data_bits().into(); 168 | let ecc_bits: usize = version.ecc_bits().into(); 169 | 170 | let mut data = s[..data_bits].to_string(); 171 | let mut ecc = s[data_bits..data_bits + ecc_bits].to_string(); 172 | 173 | // pad 174 | data.push_str(&"0".repeat(data_bits - data.len() + (8 - data_bits % 8))); 175 | ecc.push_str(&"0".repeat(ecc_bits - ecc.len() + (8 - ecc_bits % 8))); 176 | 177 | // pack into bytes 178 | let mut data: Vec<u8> = data 179 | .as_bytes() 180 | .chunks(8) 181 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap()) 182 | .collect(); 183 | let ecc: Vec<u8> = ecc 184 | .as_bytes() 185 | .chunks(8) 186 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap()) 187 | .collect(); 188 | 189 | // validate and correct 190 | let mut ecc_state = bch::bch_init(version.allowed_bit_flips() as u32, bch::POLYNOMIAL); 191 | let bitflips = bch::bch_decode(&mut ecc_state, &mut data, &ecc); 192 | 193 | if bitflips > version.allowed_bit_flips() { 194 | return Err(Error::CorruptWatermark); 195 | } 196 | 197 | // unpack data and ecc 198 | let mut data: String = data 199 | .iter() 200 | .map(|byte| format!("{byte:08b}")) 201 | .collect::<Vec<String>>() 202 | .join(""); 203 | let mut ecc: String = ecc 204 | .iter() 205 | .map(|byte| format!("{byte:08b}")) 206 | .collect::<Vec<String>>() 207 | .join(""); 208 | data.truncate(data_bits); 209 | ecc.truncate(ecc_bits); 210 | 211 | Ok(Bits(format!("{data}{ecc}{}", version.bitstring()))) 212 | } 213 | } 214 | 215 | impl From<Bits> for ort::Value<TensorValueType<f32>> { 216 | fn from(Bits(s): Bits) -> Self { 217 | let floats: Vec<f32> = s 218 | .chars() 219 | .map(|c| match c { 220 | '0' => 0.0, 221 | '1' => 1.0, 222 | _ => unreachable!(), 223 | }) 224 | .collect(); 225 | 226 | let array = Array1::from(floats); 227 |
Value::from_array(array.insert_axis(Axis(0))).unwrap() 228 | } 229 | } 230 | 231 | impl TryFrom<ArrayD<f32>> for Bits { 232 | type Error = Error; 233 | 234 | fn try_from(array: ArrayD<f32>) -> Result<Self, Self::Error> { 235 | if array.shape() != [1, 100] { 236 | return Err(Error::InvalidDim); 237 | } 238 | let array = array.remove_axis(Axis(0)); 239 | let mut s = String::new(); 240 | for bit in array.iter() { 241 | let c = if *bit < 0. { '0' } else { '1' }; 242 | s.push(c); 243 | } 244 | 245 | Bits::new(s) 246 | } 247 | } 248 | 249 | /// The error correction schema 250 | #[derive(Debug, Default, Copy, Clone, PartialEq)] 251 | pub enum Version { 252 | /// Tolerates 8 bit flips 253 | #[default] 254 | BchSuper, 255 | /// Tolerates 5 bit flips 256 | Bch5, 257 | /// Tolerates 4 bit flips 258 | Bch4, 259 | /// Tolerates 3 bit flips 260 | Bch3, 261 | } 262 | 263 | impl Display for Version { 264 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 265 | let s = match self { 266 | Version::BchSuper => "BCH_SUPER", 267 | Version::Bch5 => "BCH_5", 268 | Version::Bch4 => "BCH_4", 269 | Version::Bch3 => "BCH_3", 270 | }; 271 | write!(f, "{s}") 272 | } 273 | } 274 | 275 | impl FromStr for Version { 276 | type Err = Error; 277 | 278 | fn from_str(s: &str) -> Result<Self, Self::Err> { 279 | let version = match s { 280 | "BCH_SUPER" => Version::BchSuper, 281 | "BCH_5" => Version::Bch5, 282 | "BCH_4" => Version::Bch4, 283 | "BCH_3" => Version::Bch3, 284 | _ => return Err(Error::InvalidVersion), 285 | }; 286 | 287 | Ok(version) 288 | } 289 | } 290 | 291 | impl Version { 292 | /// Get the number of allowed bit flips for this version. 293 | fn allowed_bit_flips(&self) -> u8 { 294 | match self { 295 | Version::BchSuper => 8, 296 | Version::Bch5 => 5, 297 | Version::Bch4 => 4, 298 | Version::Bch3 => 3, 299 | } 300 | } 301 | 302 | /// Get the number of data bits for this version. 303 | pub fn data_bits(&self) -> u16 { 304 | match self { 305 | Version::BchSuper => 40, 306 | Version::Bch5 => 61, 307 | Version::Bch4 => 68, 308 | Version::Bch3 => 75, 309 | } 310 | } 311 | 312 | /// Get the bitstring which indicates this version. 313 | fn bitstring(&self) -> String { 314 | match self { 315 | Version::BchSuper => "0000".to_owned(), 316 | Version::Bch5 => "0001".to_owned(), 317 | Version::Bch4 => "0010".to_owned(), 318 | Version::Bch3 => "0011".to_owned(), 319 | } 320 | } 321 | 322 | /// Parse a version from a bitstring. 323 | fn from_bitstring(s: &str) -> Result<Self, Error> { 324 | Ok(match s { 325 | "0000" => Version::BchSuper, 326 | "0001" => Version::Bch5, 327 | "0010" => Version::Bch4, 328 | "0011" => Version::Bch3, 329 | _ => return Err(Error::InvalidVersion), 330 | }) 331 | } 332 | 333 | /// Get the number of error correcting bits for this version.
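///
/// Everything in the 100-bit payload that is not data or the 4 version bits
/// is ECC: `BchSuper` yields 100 - 4 - 40 = 56 ECC bits (matching its 8
/// correctable flips at 7 bits per flip), while `Bch5` yields 100 - 4 - 61 = 35.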
334 | fn ecc_bits(&self) -> u16 { 335 | 100 - VERSION_BITS - self.data_bits() 336 | } 337 | } 338 | 339 | #[cfg(test)] 340 | mod tests { 341 | use super::*; 342 | 343 | #[test] 344 | fn get_version() { 345 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned(); 346 | let bits = Bits(input); 347 | assert_eq!(bits.get_version(), Version::Bch5); 348 | } 349 | 350 | #[test] 351 | fn get_data() { 352 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned(); 353 | let bits = Bits(input); 354 | assert_eq!( 355 | bits.get_data(), 356 | "1011011110011000111111000000011111011111011100000110110110111" 357 | ); 358 | } 359 | 360 | #[test] 361 | fn new() { 362 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned(); 363 | let bits = Bits::new(input).unwrap(); 364 | assert_eq!( 365 | bits.get_data(), 366 | "1011011110011000111111000000011111011111011100000110110110111" 367 | ); 368 | } 369 | 370 | #[test] 371 | fn fully_corrupted() { 372 | let input = "0000000000000000000000000000000000000000000100000110110110111000110010101101111010011011000010000001".to_owned(); 373 | let err = Bits::new(input).unwrap_err(); 374 | assert_eq!(err.to_string(), "corrupt watermark"); 375 | } 376 | 377 | #[test] 378 | fn single_bitflip() { 379 | let input = "0011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned(); 380 | let bits = Bits::new(input).unwrap(); 381 | assert_eq!( 382 | bits.get_data(), 383 | "1011011110011000111111000000011111011111011100000110110110111" 384 | ); 385 | } 386 | 387 | #[test] 388 | fn single_bitflip_and_corrupted_version() { 389 | let input = "0011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000011".to_owned(); 390 | let bits = Bits::new(input).unwrap(); 391 | assert_eq!( 392 | bits.get_data(), 393 | "1011011110011000111111000000011111011111011100000110110110111" 394 | ); 395 | } 396 | 397 | #[test] 398 | fn invalid_bitstring() { 399 | let err = Bits::apply_error_correction_and_schema("hello".to_string(), Version::Bch5) 400 | .unwrap_err(); 401 | assert!(matches!(err, Error::InvalidChar)); 402 | } 403 | 404 | #[test] 405 | fn too_long_input() { 406 | let err = 407 | Bits::apply_error_correction_and_schema("0".repeat(200), Version::Bch5).unwrap_err(); 408 | assert!(matches!(err, Error::InvalidDataLength { .. })); 409 | } 410 | 411 | #[test] 412 | fn corrupt() { 413 | let err = Bits::new("1".repeat(100)).unwrap_err(); 414 | assert!(matches!(err, Error::CorruptWatermark)); 415 | } 416 | 417 | #[test] 418 | fn invalid_dim() { 419 | let ar = ArrayD::<f32>::zeros(ndarray::IxDyn(&[3])); 420 | let res: Result<Bits, _> = ar.try_into(); 421 | assert!(matches!(res.unwrap_err(), Error::InvalidDim)); 422 | } 423 | } 424 | -------------------------------------------------------------------------------- /rust/src/image_processing.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Adobe 2 | // All Rights Reserved. 3 | // 4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in 5 | // accordance with the terms of the Adobe license agreement accompanying 6 | // it.
7 | 8 | use std::cmp; 9 | 10 | use fast_image_resize::{ResizeAlg, ResizeOptions, Resizer}; 11 | use image::{ 12 | imageops::{self, FilterType}, 13 | DynamicImage, GenericImageView as _, GrayAlphaImage, GrayImage, ImageBuffer, Pixel as _, 14 | Rgb32FImage, RgbImage, Rgba32FImage, RgbaImage, 15 | }; 16 | use ndarray::{s, Array, ArrayD, Axis, ShapeError}; 17 | use ort::TensorValueType; 18 | 19 | use crate::Variant; 20 | 21 | /// Re-normalize a floating point value (either scalar or array) from the range [0,1] to the range 22 | /// [-1, 1]. 23 | macro_rules! convert_from_0_1_to_neg1_1 { 24 | ($f:expr) => { 25 | $f * 2. - 1. 26 | }; 27 | } 28 | 29 | /// Re-normalize a floating point value (either scalar or array) from the range [-1, 1] to the 30 | /// range [0, 1]. 31 | macro_rules! convert_from_neg1_1_to_0_1 { 32 | ($f:expr) => { 33 | ($f + 1.) / 2. 34 | }; 35 | } 36 | 37 | pub(super) struct ModelImage(pub(super) u32, pub(super) Variant, pub(super) DynamicImage); 38 | 39 | /// The error type for the `image_processing` module. 40 | #[derive(Debug, thiserror::Error)] 41 | pub enum Error { 42 | /// Something went wrong during inference. 43 | #[error("onnx error: {0}")] 44 | Ort(#[from] ort::Error), 45 | 46 | /// We were unable to make an `ndarray::Array` of the requested shape. 47 | #[error("shape error: {0}")] 48 | Shape(#[from] ShapeError), 49 | 50 | /// The input array has an unexpected shape. 51 | #[error("invalid shape")] 52 | InvalidShape, 53 | 54 | /// We could not create an `ImageBuffer` with the requested array. 55 | #[error("invalid image")] 56 | Image, 57 | 58 | /// We were unable to resize the input image. 59 | #[error("resize error: {0}")] 60 | Resize(#[from] fast_image_resize::ResizeError), 61 | } 62 | 63 | impl TryFrom<ModelImage> for ort::Value<TensorValueType<f32>> { 64 | type Error = Error; 65 | 66 | fn try_from(ModelImage(size, variant, img): ModelImage) -> Result<Self, Self::Error> { 67 | let (w, h, xpos, ypos) = center_crop_size_and_offset(variant, &img); 68 | 69 | let options = ResizeOptions::new() 70 | .crop(xpos as f64, ypos as f64, w as f64, h as f64) 71 | .resize_alg(ResizeAlg::Interpolation( 72 | fast_image_resize::FilterType::Bilinear, 73 | )); 74 | let modified_img = resize_img(&img, size, size, options)?; 75 | 76 | let img = modified_img.into_rgb32f().into_vec(); 77 | let array = Array::from(img); 78 | 79 | // The `image` crate normalizes to `[0,1]`. Trustmark wants images normalized to `[-1,1]`. 80 | let array = convert_from_0_1_to_neg1_1!(array); 81 | 82 | let mut array = array 83 | .to_shape([size as usize, size as usize, 3])? 84 | .insert_axis(Axis(3)) 85 | .reversed_axes(); 86 | array.swap_axes(2, 3); 87 | assert_eq!(array.shape(), &[1, 3, size as usize, size as usize]); 88 | Ok(ort::Value::from_array(&array)?) 89 | } 90 | } 91 | 92 | impl TryFrom<(u32, Variant, ArrayD<f32>)> for ModelImage { 93 | type Error = Error; 94 | 95 | fn try_from( 96 | (size, variant, mut array): (u32, Variant, ArrayD<f32>), 97 | ) -> Result<Self, Self::Error> { 98 | let &[1, 3, height, width] = &array.shape().to_owned()[..] else { 99 | return Err(Error::InvalidShape); 100 | }; 101 | array.swap_axes(2, 3); 102 | let array = array.reversed_axes().remove_axis(Axis(3)); 103 | let array = array.to_shape([width * height * 3])?; 104 | 105 | // The `image` crate normalizes to `[0,1]`. Trustmark wants images normalized to `[-1,1]`.
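// For example, a model output of -1.0 maps back to 0.0, 0.0 to 0.5, and 1.0
// to 1.0 -- the inverse of the mapping applied on the way into the model.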
106 | let array = convert_from_neg1_1_to_0_1!(array); 107 | 108 | let image = Rgb32FImage::from_vec(width as u32, height as u32, array.to_vec()) 109 | .ok_or(Error::Image)?; 110 | 111 | Ok(Self(size, variant, image.into())) 112 | } 113 | } 114 | 115 | /// Apply `residual` to the `input`. 116 | /// 117 | /// This function upscales `residual` to be the size of `input`, then adds `residual` to the 118 | /// `input`. 119 | pub(super) fn apply_residual(input: DynamicImage, residual: DynamicImage) -> DynamicImage { 120 | let has_alpha = input.color().has_alpha(); 121 | let (w, h) = input.dimensions(); 122 | 123 | let applied = { 124 | let input = input.clone().into_rgba32f(); 125 | let mut target = input.clone(); 126 | 127 | let residual = residual.resize_exact(w, h, FilterType::Triangle); 128 | let residual = residual.into_rgba32f(); 129 | 130 | for ((target, residual), original) in target 131 | .pixels_mut() 132 | .zip(residual.pixels()) 133 | .zip(input.pixels()) 134 | { 135 | target.apply2(residual, |x, y| { 136 | let x = convert_from_0_1_to_neg1_1!(x); 137 | let y = convert_from_0_1_to_neg1_1!(y); 138 | 139 | convert_from_neg1_1_to_0_1!(f32::min(x + y, 1.0)) 140 | }); 141 | target[3] = original[3]; 142 | } 143 | 144 | target 145 | }; 146 | 147 | if has_alpha { 148 | let mut input = input.into_rgba32f(); 149 | imageops::replace(&mut input, &applied, 0, 0); 150 | input.into() 151 | } else { 152 | let mut input = input.into_rgb32f(); 153 | let applied = DynamicImage::ImageRgba32F(applied).into_rgb32f(); 154 | imageops::replace(&mut input, &applied, 0, 0); 155 | input.into() 156 | } 157 | } 158 | 159 | /// Return the size and offset of the "center-cropped" image. 160 | /// 161 | /// Returns `(width, height, xpos, ypos)` for the region to crop (a square when cropping occurs). 162 | /// 163 | /// For tall, skinny images or short, wide images, we want to crop a square image with side length of 164 | /// the shorter side out of the center of the image for the model. 165 | fn center_crop_size_and_offset(variant: Variant, img: &DynamicImage) -> (u32, u32, u32, u32) { 166 | let (width, height) = img.dimensions(); 167 | 168 | if height > width * 2 || width > height * 2 || variant == Variant::P { 169 | let m = cmp::min(height, width); 170 | let offset = (cmp::max(height, width) - m) / 2; 171 | 172 | let xpos; 173 | let ypos; 174 | if height > width { 175 | xpos = 0; 176 | ypos = offset; 177 | } else { 178 | ypos = 0; 179 | xpos = offset; 180 | } 181 | 182 | (m, m, xpos, ypos) 183 | } else { 184 | (width, height, 0, 0) 185 | } 186 | } 187 | 188 | /// Returns a new `DynamicImage`, resized to `width` by `height` with the specified `ResizeOptions`.
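///
/// The destination buffer is allocated with the same color type as `img` (an
/// `ImageRgba8` source is resized into a fresh `ImageRgba8`, and so on), so
/// resizing never changes the pixel format.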
189 | fn resize_img( 190 | img: &DynamicImage, 191 | width: u32, 192 | height: u32, 193 | options: ResizeOptions, 194 | ) -> Result<DynamicImage, Error> { 195 | let mut modified_img = match img { 196 | DynamicImage::ImageLuma8(_) => DynamicImage::ImageLuma8(GrayImage::new(width, height)), 197 | DynamicImage::ImageLumaA8(_) => { 198 | DynamicImage::ImageLumaA8(GrayAlphaImage::new(width, height)) 199 | } 200 | DynamicImage::ImageRgb8(_) => DynamicImage::ImageRgb8(RgbImage::new(width, height)), 201 | DynamicImage::ImageRgba8(_) => DynamicImage::ImageRgba8(RgbaImage::new(width, height)), 202 | DynamicImage::ImageLuma16(_) => DynamicImage::ImageLuma16(ImageBuffer::new(width, height)), 203 | DynamicImage::ImageLumaA16(_) => { 204 | DynamicImage::ImageLumaA16(ImageBuffer::new(width, height)) 205 | } 206 | DynamicImage::ImageRgb16(_) => DynamicImage::ImageRgb16(ImageBuffer::new(width, height)), 207 | DynamicImage::ImageRgba16(_) => DynamicImage::ImageRgba16(ImageBuffer::new(width, height)), 208 | DynamicImage::ImageRgb32F(_) => DynamicImage::ImageRgb32F(Rgb32FImage::new(width, height)), 209 | DynamicImage::ImageRgba32F(_) => { 210 | DynamicImage::ImageRgba32F(Rgba32FImage::new(width, height)) 211 | } 212 | // Technically unreachable, but we error for safety. 213 | _ => return Err(Error::Image), 214 | }; 215 | Resizer::new().resize(img, &mut modified_img, &options)?; 216 | 217 | Ok(modified_img) 218 | } 219 | 220 | /// Applies the mean padding boundary artifact mitigation. 221 | /// 222 | /// Center cropped images have a vertical line problem along the boundary of the residual. This 223 | /// transformation makes this boundary less visible. 224 | pub(super) fn remove_boundary_artifact( 225 | mut residual: ArrayD<f32>, 226 | (width, height): (usize, usize), 227 | _variant: Variant, 228 | ) -> ArrayD<f32> { 229 | // We're going to replace the border of the residual with the mean and also pad the non-center 230 | // areas with the mean value. 231 | let channel_means: Vec<f32> = (0_usize..3) 232 | .map(|i| residual.slice(s![.., i, .., ..]).mean().unwrap()) 233 | .collect(); 234 | 235 | // We want one dimension of the output to be 256 and we want the aspect ratio of the output 236 | // to match the input image. 237 | let mut mean_padded: ndarray::Array4<f32> = if width > height { 238 | let other = ((width as f32 / height as f32) * 256.0) as usize; 239 | ndarray::Array4::zeros([1, 3, 256_usize, other]) 240 | } else { 241 | let other = ((height as f32 / width as f32) * 256.0) as usize; // float division, matching the slice computed for the assignment below 242 | ndarray::Array4::zeros([1, 3, other, 256]) 243 | }; 244 | 245 | // This softens the transition between the residual area and the rest of the image.
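// Worked example (illustrative numbers): for a 1024x512 input, `other` is
// (1024/512) * 256 = 512, so `mean_padded` is [1, 3, 256, 512]; after the
// border fill below, the 256x256 residual is assigned to columns 128..384
// while the side panels keep the per-channel means.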
247 |     let border = 2;
248 |     for (i, mean) in channel_means.iter().enumerate() {
249 |         residual.slice_mut(s![0, i, ..border, ..]).fill(*mean);
250 |         residual.slice_mut(s![0, i, -border.., ..]).fill(*mean);
251 |         residual.slice_mut(s![0, i, .., -border..]).fill(*mean);
252 |         residual.slice_mut(s![0, i, .., ..border]).fill(*mean);
253 |         mean_padded.slice_mut(s![0, i, .., ..]).fill(*mean);
254 |     }
255 | 
256 |     if width > height {
257 |         let other = ((width as f32 / height as f32) * 256.0) as usize;
258 |         let leftover = (other - 256) / 2;
259 |         mean_padded
260 |             .slice_mut(s![.., .., .., leftover..(leftover + 256)])
261 |             .assign(&residual);
262 |     } else {
263 |         let other = ((height as f32 / width as f32) * 256.0) as usize;
264 |         let leftover = (other - 256) / 2;
265 |         mean_padded
266 |             .slice_mut(s![.., .., leftover..(leftover + 256), ..])
267 |             .assign(&residual);
268 |     }
269 | 
270 |     mean_padded.into_dyn()
271 | }
272 | 
273 | #[cfg(test)]
274 | mod tests {
275 |     use super::*;
276 | 
277 |     #[test]
278 |     fn renormalize_from_0_1() {
279 |         assert_eq!(convert_from_0_1_to_neg1_1!(0.), -1.);
280 |         assert_eq!(convert_from_0_1_to_neg1_1!(0.5), 0.);
281 |         assert_eq!(convert_from_0_1_to_neg1_1!(0.99), 0.98);
282 |     }
283 | 
284 |     #[test]
285 |     fn renormalize_from_neg1_1() {
286 |         assert_eq!(convert_from_neg1_1_to_0_1!(-1.), 0.);
287 |         assert_eq!(convert_from_neg1_1_to_0_1!(0.5), 0.75);
288 |         assert_eq!(convert_from_neg1_1_to_0_1!(-0.1), 0.45);
289 |     }
290 | 
291 |     #[test]
292 |     fn normal_image() {
293 |         let image = DynamicImage::new(100, 110, image::ColorType::L8);
294 |         assert_eq!(
295 |             center_crop_size_and_offset(Variant::Q, &image),
296 |             (100, 110, 0, 0)
297 |         );
298 |     }
299 | 
300 |     #[test]
301 |     fn skinny_image() {
302 |         let image = DynamicImage::new(10, 100, image::ColorType::L8);
303 |         assert_eq!(
304 |             center_crop_size_and_offset(Variant::Q, &image),
305 |             (10, 10, 0, 45)
306 |         );
307 |     }
308 | 
309 |     #[test]
310 |     fn wide_image() {
311 |         let image = DynamicImage::new(101, 10, image::ColorType::L8);
312 |         assert_eq!(
313 |             center_crop_size_and_offset(Variant::Q, &image),
314 |             (10, 10, 45, 0)
315 |         );
316 |     }
317 | 
318 |     #[test]
319 |     fn always_crop_p() {
320 |         let image = DynamicImage::new(100, 110, image::ColorType::L8);
321 |         assert_eq!(
322 |             center_crop_size_and_offset(Variant::P, &image),
323 |             (100, 100, 0, 5)
324 |         );
325 |     }
326 | }
--------------------------------------------------------------------------------
/rust/src/lib.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 | 
8 | //! # Trustmark
9 | //!
10 | //! An implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) in
11 | //! Rust, as described in:
12 | //!
13 | //! ---
14 | //!
15 | //! **TrustMark - Universal Watermarking for Arbitrary Resolution Images**
16 | //!
17 | //!
18 | //!
19 | //! [Tu Bui]<sup>1</sup>, [Shruti Agarwal]<sup>2</sup>, [John Collomosse]<sup>1,2</sup>
20 | //!
21 | //! <sup>1</sup>DECaDE Centre for the Decentralized Digital Economy, University of Surrey, UK.\
22 | //! <sup>2</sup>Adobe Research, San Jose CA.
23 | //!
24 | //! ---
25 | //!
26 | //! This is a re-implementation of the [trustmark] Python library.
27 | //!
28 | //! [Tu Bui]: https://www.surrey.ac.uk/people/tu-bui
29 | //! [Shruti Agarwal]: https://research.adobe.com/person/shruti-agarwal/
30 | //! [John Collomosse]: https://www.collomosse.com/
31 | //! [trustmark]: https://pypi.org/project/trustmark/
32 | //!
33 | //! ## Example
34 | //!
35 | //! ```rust
36 | //! use trustmark::{Trustmark, Version, Variant};
37 | //!
38 | //! # fn main() {
39 | //! let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
40 | //! let input = image::open("../images/ghost.png").unwrap();
41 | //! let output = tm.encode("0010101".to_owned(), input, 0.95);
42 | //! # }
43 | //! ```
44 | use std::path::Path;
45 | 
46 | use image::{DynamicImage, GenericImageView as _};
47 | use ort::{GraphOptimizationLevel, Session};
48 | 
49 | use self::{bits::Bits, image_processing::ModelImage};
50 | 
51 | mod bits;
52 | mod image_processing;
53 | mod model;
54 | 
55 | /// A loaded Trustmark model.
56 | pub struct Trustmark {
57 |     encoder: Session,
58 |     decoder: Session,
59 |     version: Version,
60 |     variant: Variant,
61 | }
62 | 
63 | #[derive(Debug, thiserror::Error)]
64 | pub enum Error {
65 |     #[error("watermark is corrupt or missing")]
66 |     CorruptWatermark,
67 |     #[error("onnx error: {0}")]
68 |     Ort(#[from] ort::Error),
69 |     #[error("image processing error: {0}")]
70 |     ImageProcessing(#[from] image_processing::Error),
71 |     #[error("bits processing error: {0}")]
72 |     Bits(bits::Error),
73 |     #[error("invalid model variant")]
74 |     InvalidModelVariant,
75 | }
76 | 
77 | impl From<bits::Error> for Error {
78 |     fn from(value: bits::Error) -> Self {
79 |         match value {
80 |             bits::Error::CorruptWatermark => Error::CorruptWatermark,
81 |             err => Error::Bits(err),
82 |         }
83 |     }
84 | }
85 | 
86 | pub use bits::Version;
87 | pub use model::Variant;
88 | 
89 | impl Trustmark {
90 |     /// Load a Trustmark model.
91 |     pub fn new<P: AsRef<Path>>(
92 |         models: P,
93 |         variant: Variant,
94 |         version: Version,
95 |     ) -> Result<Self, Error> {
96 |         let encoder = Session::builder()?
97 |             .with_optimization_level(GraphOptimizationLevel::Level3)?
98 |             .with_intra_threads(8)?
99 |             .commit_from_file(models.as_ref().join(variant.encoder_filename()))?;
100 |         let decoder = Session::builder()?
101 |             .with_optimization_level(GraphOptimizationLevel::Level3)?
102 |             .with_intra_threads(8)?
103 |             .commit_from_file(models.as_ref().join(variant.decoder_filename()))?;
104 |         Ok(Self {
105 |             encoder,
106 |             decoder,
107 |             version,
108 |             variant,
109 |         })
110 |     }
111 | 
112 |     /// Encode a watermark into an image.
113 |     ///
114 |     /// `watermark` is a bitstring encoding the watermark identifier to encode. `img` is the image
115 |     /// which will be watermarked. `strength` is a number between 0 and 1 indicating how strong the
116 |     /// resulting watermark should be; 0.95 is a typical strength.
117 |     pub fn encode(
118 |         &self,
119 |         watermark: String,
120 |         img: DynamicImage,
121 |         strength: f32,
122 |     ) -> Result<DynamicImage, Error> {
123 |         let (original_width, original_height) = img.dimensions();
124 |         let aspect_ratio = original_width as f32 / original_height as f32;
125 | 
126 |         // The image is always encoded at 256x256.
127 |         let encode_size = 256;
128 | 
129 |         let input_img: ort::Value<ort::TensorValueType<f32>> =
130 |             ModelImage(encode_size, self.variant, img.clone()).try_into()?;
131 |         let bits: ort::Value<ort::TensorValueType<f32>> =
132 |             Bits::apply_error_correction_and_schema(watermark, self.version)?.into();
133 |         let outputs = self.encoder.run(ort::inputs![
134 |             "onnx::Concat_0" => input_img,
135 |             "onnx::Gemm_1" => bits,
136 |         ]?)?;
137 |         let output_img = outputs["image"].try_extract_tensor::<f32>()?.to_owned();
138 | 
139 |         // Need to calculate and apply the residual.
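// Editorial note (not part of the original source): the statements below implement
//
//     residual = clamp(strength_multiplier(variant) * strength * (encoder_output - model_input), -0.2, 0.2)
//
// so with the Q variant (multiplier 1.0) at the suggested strength of 0.95, no value of the
// 256x256 residual can perturb the [-1, 1]-normalized image by more than 0.2.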
140 |         let input_img: ort::Value<ort::TensorValueType<f32>> =
141 |             ModelImage(encode_size, self.variant, img.clone()).try_into()?;
142 |         let residual = (self.variant.strength_multiplier() * strength)
143 |             * (output_img - input_img.try_extract_tensor::<f32>()?);
144 | 
145 |         // Residual should be small perturbations.
146 |         let mut residual = residual.clamp(-0.2, 0.2);
147 |         if (self.variant == Variant::Q && !(0.5..=2.0).contains(&aspect_ratio))
148 |             || self.variant == Variant::P
149 |         {
150 |             residual = image_processing::remove_boundary_artifact(
151 |                 residual,
152 |                 (original_width as usize, original_height as usize),
153 |                 self.variant,
154 |             );
155 |         }
156 | 
157 |         let ModelImage(_, _, residual) = (encode_size, self.variant, residual).try_into()?;
158 | 
159 |         Ok(image_processing::apply_residual(img, residual))
160 |     }
161 | 
162 |     /// Decode a watermark from an image.
163 |     pub fn decode(&self, img: DynamicImage) -> Result<String, Error> {
164 |         // The P variant has a smaller decode size.
165 |         let decode_size = if self.variant == Variant::P { 224 } else { 256 };
166 | 
167 |         let img: ort::Value<ort::TensorValueType<f32>> =
168 |             ModelImage(decode_size, self.variant, img).try_into()?;
169 |         let outputs = self.decoder.run(ort::inputs![
170 |             "image" => img,
171 |         ]?)?;
172 |         let watermark = outputs["output"].try_extract_tensor::<f32>()?.to_owned();
173 |         let watermark: Bits = watermark.try_into()?;
174 |         Ok(watermark.get_data())
175 |     }
176 | }
177 | 
178 | #[cfg(test)]
179 | mod tests {
180 |     use super::*;
181 | 
182 |     #[test]
183 |     fn loading_models() {
184 |         Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
185 |     }
186 | 
187 |     fn roundtrip(path: impl AsRef<Path>) {
188 |         let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
189 |         let input = image::open(path.as_ref()).unwrap();
190 |         let watermark = "1011011110011000111111000000011111011111011100000110110110111".to_owned();
191 |         let encoded = tm.encode(watermark.clone(), input, 0.95).unwrap();
192 |         encoded.to_rgba8().save("./test.png").unwrap();
193 |         let input = image::open("./test.png").unwrap();
194 |         let decoded = tm.decode(input).unwrap();
195 |         assert_eq!(watermark, decoded);
196 |     }
197 | 
198 |     #[test]
199 |     fn roundtrip_ghost() {
200 |         roundtrip("../images/ghost.png");
201 |     }
202 | 
203 |     #[test]
204 |     fn roundtrip_ufo() {
205 |         roundtrip("../images/ufo_240.jpg");
206 |     }
207 | }
--------------------------------------------------------------------------------
/rust/src/model.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 | 
8 | use std::{fmt::Display, str::FromStr};
9 | 
10 | use crate::Error;
11 | 
12 | /// The model variant to load.
13 | #[derive(Copy, Clone, Debug, PartialEq)]
14 | pub enum Variant {
15 |     B,
16 |     C,
17 |     P,
18 |     Q,
19 | }
20 | 
21 | impl Variant {
22 |     pub(super) fn encoder_filename(&self) -> String {
23 |         let suffix = match self {
24 |             Variant::B => "B",
25 |             Variant::C => "C",
26 |             Variant::P => "P",
27 |             Variant::Q => "Q",
28 |         };
29 | 
30 |         format!("encoder_{suffix}.onnx")
31 |     }
32 | 
33 |     pub(super) fn decoder_filename(&self) -> String {
34 |         let suffix = match self {
35 |             Variant::B => "B",
36 |             Variant::C => "C",
37 |             Variant::P => "P",
38 |             Variant::Q => "Q",
39 |         };
40 | 
41 |         format!("decoder_{suffix}.onnx")
42 |     }
43 | 
44 |     pub(super) fn strength_multiplier(&self) -> f32 {
45 |         match self {
46 |             Variant::P => 1.25,
47 |             _ => 1.,
48 |         }
49 |     }
50 | }
51 | 
52 | impl Display for Variant {
53 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 |         let s = match self {
55 |             Variant::B => "B",
56 |             Variant::C => "C",
57 |             Variant::P => "P",
58 |             Variant::Q => "Q",
59 |         };
60 | 
61 |         f.write_str(s)
62 |     }
63 | }
64 | 
65 | impl FromStr for Variant {
66 |     type Err = Error;
67 | 
68 |     fn from_str(s: &str) -> Result<Self, Self::Err> {
69 |         Ok(match s {
70 |             "B" => Variant::B,
71 |             "C" => Variant::C,
72 |             "P" => Variant::P,
73 |             "Q" => Variant::Q,
74 |             _ => return Err(Error::InvalidModelVariant),
75 |         })
76 |     }
77 | }
--------------------------------------------------------------------------------
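A short usage sketch tying the pieces above together (an editorial addition, not a file from the repository): it parses a `Variant` via the `FromStr` impl in `model.rs`, then round-trips a watermark with the `Trustmark` API from `lib.rs`. The model directory and image path mirror those used in the crate's own tests; the output filename `watermarked.png` is hypothetical.

```rust
use trustmark::{Trustmark, Variant, Version};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // "Q".parse() uses the FromStr impl above; an unknown letter yields Error::InvalidModelVariant.
    let variant: Variant = "Q".parse()?;
    let tm = Trustmark::new("./models", variant, Version::Bch5)?;

    // Encode a 61-bit payload at the suggested strength of 0.95, then decode it back.
    let input = image::open("../images/ghost.png")?;
    let watermark = "1011011110011000111111000000011111011111011100000110110110111".to_owned();
    let encoded = tm.encode(watermark.clone(), input, 0.95)?;
    encoded.to_rgba8().save("./watermarked.png")?;

    let decoded = tm.decode(image::open("./watermarked.png")?)?;
    assert_eq!(watermark, decoded);
    Ok(())
}
```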