import requests
import base64
import json
import time
import os
import hashlib
from Crypto.Cipher import AES, PKCS1_OAEP
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from Crypto.Hash import SHA256

# --- Configuration ---
# Base URL of the server; every endpoint path in this script is appended to it.
SERVER_URL = "http://127.0.0.1:8000"

# Placeholder identity values sent inside the encrypted verify payload.
INSTANCE_ID = "YOUR_INSTANCE_ID"
HWID = "HWID-XXXX"
CARD_KEY = "KEY-XXXX"

# AES-GCM framing sizes, in bytes, used to slice encrypted blobs and streams.
GCM_NONCE_SIZE = 12
GCM_TAG_SIZE = 16

def get_server_public_key(timeout=10):
    """Fetch the server's RSA public key and import it.

    Args:
        timeout: Seconds passed to ``requests.get``. Without a timeout,
            requests waits indefinitely on a server that never answers.

    Returns:
        The RSA key object produced by ``RSA.import_key``.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            4xx/5xx status (via ``raise_for_status``).
        KeyError: If the JSON body has no ``"public_key"`` field.
        ValueError: If the body is not JSON or the key cannot be parsed.
    """
    resp = requests.get(f"{SERVER_URL}/api/v2/secure/public_key", timeout=timeout)
    resp.raise_for_status()
    return RSA.import_key(resp.json()["public_key"])

def encrypt_aes_gcm(key, plaintext, nonce_size=GCM_NONCE_SIZE):
    """Encrypt *plaintext* with AES-GCM under *key*.

    PyCryptodome generates a 16-byte nonce when no nonce is passed. This
    client, however, slices every server-produced blob with a
    ``GCM_NONCE_SIZE``-byte (12) nonce. The nonce is therefore generated
    explicitly so the request blob uses the same framing.
    NOTE(review): presumably the server parses request blobs with the same
    12-byte layout; confirm against the server implementation.

    Args:
        key: AES key (16, 24 or 32 bytes).
        plaintext: Bytes to encrypt.
        nonce_size: Length of the random nonce in bytes.

    Returns:
        Tuple ``(ciphertext, nonce, tag)``.
    """
    nonce = get_random_bytes(nonce_size)
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
    ciphertext, tag = cipher.encrypt_and_digest(plaintext)
    return ciphertext, nonce, tag

def decrypt_aes_gcm(key, ciphertext, nonce, tag):
    """Return the plaintext of an AES-GCM *ciphertext* after checking *tag*.

    The PyCryptodome cipher raises ``ValueError`` when the tag does not
    authenticate the ciphertext.
    """
    return AES.new(key, AES.MODE_GCM, nonce=nonce).decrypt_and_verify(ciphertext, tag)

def _split_gcm_blob(blob):
    """Split a ``nonce || tag || ciphertext`` blob into ``(nonce, tag, ciphertext)``.

    A truncated blob simply yields short parts. Authentication then fails
    during decryption instead of here.
    """
    tag_end = GCM_NONCE_SIZE + GCM_TAG_SIZE
    return blob[:GCM_NONCE_SIZE], blob[GCM_NONCE_SIZE:tag_end], blob[tag_end:]


def _decrypt_gcm_stream(chunks, key):
    """Decrypt and verify a streamed ``nonce || ciphertext || tag`` AES-GCM body.

    Every chunk is fed to the cipher, including bytes that arrive in the same
    chunk as the nonce. The trailing ``GCM_TAG_SIZE`` bytes are always held
    back, because the tag's position is only known once the stream ends.

    Args:
        chunks: Iterable of bytes-like chunks (e.g. ``Response.iter_content``).
        key: AES key for the file.

    Returns:
        bytearray holding the authenticated plaintext.

    Raises:
        ValueError: With a human-readable reason if the nonce or tag is
            missing, or if the tag does not authenticate the data.
    """
    buffer = bytearray()
    plaintext = bytearray()
    cipher = None
    for chunk in chunks:
        buffer.extend(chunk)
        if cipher is None:
            if len(buffer) < GCM_NONCE_SIZE:
                continue
            cipher = AES.new(key, AES.MODE_GCM, nonce=bytes(buffer[:GCM_NONCE_SIZE]))
            del buffer[:GCM_NONCE_SIZE]
        # Decrypt everything except the last GCM_TAG_SIZE bytes (potential tag).
        ready = len(buffer) - GCM_TAG_SIZE
        if ready > 0:
            plaintext.extend(cipher.decrypt(bytes(buffer[:ready])))
            del buffer[:ready]

    if cipher is None:
        raise ValueError("Invalid file stream (nonce)")
    if len(buffer) != GCM_TAG_SIZE:
        raise ValueError("Stream ended prematurely (tag missing)")
    try:
        cipher.verify(bytes(buffer))
    except ValueError:
        raise ValueError("Integrity Check Failed! (Tag mismatch)") from None
    return plaintext


def main(*, timeout=30):
    """Run the verify handshake, then download and verify the protected file.

    Steps:
        1. Fetch the server's RSA public key.
        2. Encrypt the verify payload with a fresh AES-GCM key, and wrap that
           key with RSA-OAEP (SHA-256).
        3. POST both to the verify endpoint.
        4. Decrypt the response with the same AES key.
        5. Stream the file and decrypt it with the returned file key.

    Progress and errors are printed, and the function returns early on the
    first failure.

    Args:
        timeout: Seconds passed to the verify and download requests, so an
            unresponsive server cannot hang the client indefinitely.
    """
    print(f"[*] Connecting to {SERVER_URL}...")

    # 1. Get Public Key
    try:
        pub_key = get_server_public_key()
        print("[+] Got Server Public Key")
    except Exception as e:
        print(f"[-] Failed to get public key: {e}")
        return

    # 2. Prepare Verify Payload. The same AES key later decrypts the reply.
    aes_key = get_random_bytes(32)
    payload = {
        "card_key": CARD_KEY,
        "instance_id": INSTANCE_ID,
        "hwid": HWID,
        "timestamp": time.time(),
    }
    data_bytes = json.dumps(payload).encode()
    ciphertext, nonce, tag = encrypt_aes_gcm(aes_key, data_bytes)
    # Format: Nonce + Tag + Ciphertext
    encrypted_data = base64.b64encode(nonce + tag + ciphertext).decode()

    # Encrypt AES Key with RSA
    cipher_rsa = PKCS1_OAEP.new(pub_key, hashAlgo=SHA256)
    encrypted_key = base64.b64encode(cipher_rsa.encrypt(aes_key)).decode()

    # 3. Send Verify Request
    print("[*] Sending Verify Request...")
    try:
        resp = requests.post(
            f"{SERVER_URL}/api/v2/secure/verify",
            json={"encrypted_key": encrypted_key, "encrypted_data": encrypted_data},
            timeout=timeout,
        )
    except requests.RequestException as e:
        print(f"[-] Verify failed: {e}")
        return

    if resp.status_code != 200:
        print(f"[-] Verify failed: {resp.text}")
        return

    # 4. Decrypt Response. Non-JSON bodies, bad base64 and tag mismatches
    # all surface as ValueError subclasses.
    try:
        enc_payload = base64.b64decode(resp.json()["payload"])
        nonce_resp, tag_resp, ciphertext_resp = _split_gcm_blob(enc_payload)
        plaintext_resp = decrypt_aes_gcm(aes_key, ciphertext_resp, nonce_resp, tag_resp)
        data = json.loads(plaintext_resp)
    except (KeyError, TypeError, ValueError) as e:
        print(f"[-] Response decryption failed: {e}")
        return

    # Pull out everything step 5 needs before reporting success.
    try:
        token = data["token"]
        file_key = bytes.fromhex(data["file_key"])
        # The token is sent as a query parameter below, so any query string
        # in the server-supplied URL is dropped.
        download_path = data["download_url"].split("?")[0]
        version = data["file_info"]["version"]
        size = data["file_info"]["size"]
    except (KeyError, TypeError, ValueError, AttributeError) as e:
        print(f"[-] Malformed verify response: {e!r}")
        return

    print("[+] Verify Success!")
    print(f"    Token: {token[:20]}...")
    print(f"    File: {version} ({size} bytes)")

    # 5. Download File
    download_url = f"{SERVER_URL}{download_path}"
    print("[*] Downloading File...")
    try:
        with requests.get(
            download_url, params={"token": token}, stream=True, timeout=timeout
        ) as r:
            if r.status_code != 200:
                print(f"[-] Download failed: {r.text}")
                return
            # Stream layout: Nonce (12) + Ciphertext + Tag (16)
            decrypted_content = _decrypt_gcm_stream(r.iter_content(chunk_size=4096), file_key)
    except requests.RequestException as e:
        print(f"[-] Download failed: {e}")
        return
    except ValueError as e:
        # Missing nonce/tag, tag mismatch, or an invalid file key length.
        print(f"[-] {e}")
        return

    print("[+] File Decryption & Verification Successful!")
    print(f"    Decrypted Size: {len(decrypted_content)} bytes")

if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 unless
    # an exception escapes.
    raise SystemExit(main())
