#!/usr/bin/python

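# Network bridge from Max/MSP Jitter to an Adafruit DotStar RGB LED strip.
# Listens on TCP (HOST:PORT below) for "JMTX" matrix chunks -- presumably
# sent by Max's jit.net.send -- displays 8-bit, 4-plane (ARGB) matrices on
# the strip, and answers each matrix with a "JMLP" latency packet.
#
# This file began as Adafruit's DotStar strand test, a basic diagnostic that
# confirmed wiring and each pixel's ability to display red, green and blue
# and to forward data down the line.  That test limited the number and color
# of lit LEDs so a couple of meters could run off USB power; this program
# lights arbitrary client-supplied colors on every pixel, so DON'T power the
# strip from USB.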

import signal
import struct
import socket
import itertools
import sys
import time
import dotstar
import hexdump
import SocketServer as socketserver
import threading

# Maximum number of matrix cells (dims[0]) copied onto the strip per frame,
# and the physical pixel count of the strip.
DISPLAY_WIDTH = 144
STRIP_PIXELS = 144

# Listen on every interface.
HOST = "0.0.0.0"
PORT = 11223

# Allocate the strip with its real (physical) pixel count.  The second
# argument is presumably the SPI bitrate in Hz (32 MHz) -- confirm against
# the dotstar module in use.
strip = dotstar.Adafruit_DotStar(STRIP_PIXELS, 32000000)

# Base brightness, scaled per color channel to white-balance the LEDs
# (red 40%, green 100%, blue 75%).
brightness = 128
rb, gb, bb = (int(scale * brightness) for scale in (0.40, 1.00, 0.75))

# Initialize the output pins, then apply the per-channel brightness.
strip.begin()
strip.setBrightness(rb, gb, bb)

# Cleanup handler: blank the strip and exit when we are interrupted or killed.
def cleanup(signal, frame):
    """Turn every pixel off, push that to the strip, then exit with status 0.

    Installed with signal.signal() below, so it is invoked as
    cleanup(signum, frame); both arguments are ignored.  NOTE: the first
    parameter shadows the ``signal`` module inside this function; the name
    is kept for interface compatibility.
    """
    strip.clear()
    strip.show()
    sys.exit(0)

# Ctrl-C sends SIGINT; a plain `kill` sends SIGTERM.  Handle both so the
# strip is not left lit either way.
signal.signal(signal.SIGINT, cleanup)
signal.signal(signal.SIGTERM, cleanup)

# A way to iterate groups of values
def grouper(iterable, n, fillvalue=None):
    """Iterate over ``iterable`` in consecutive tuples of length ``n``.

    The last tuple is padded with ``fillvalue`` when the input length is not
    a multiple of ``n``, e.g. grouper('ABCDE', 2, 'x') yields
    ('A', 'B'), ('C', 'D'), ('E', 'x').  Returns a lazy iterator.
    """
    # izip_longest was renamed zip_longest in Python 3; use whichever exists.
    zip_longest = getattr(itertools, 'zip_longest', None) or itertools.izip_longest
    # n references to the *same* iterator, so each output tuple pulls n
    # consecutive items from it.
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)


class MaxMSPJitHandler(socketserver.StreamRequestHandler):
    """Handle one client connection that streams Jitter matrices to the strip.

    One instance is created per connection.  The stream is a sequence of
    chunks, each starting with an 8-byte header: a 4-byte id and a
    big-endian unsigned 32-bit length that includes the header itself.

    * "JMTX" chunks whose cells have 4 planes of type 0 (read as unsigned
      bytes, ARGB) have their first min(dims[0], DISPLAY_WIDTH) cells written
      to the strip; other matrix types are read and discarded.  Every JMTX is
      answered with a "JMLP" packet (4-byte id + three doubles: the sender's
      time, our start time and our end time) so the client can compute
      latency.
    * Any other id is logged and hdrlen - 8 bytes are skipped.
      NOTE(review): if such chunks carry a payload beyond that, the stream
      will lose sync -- confirm against the sender's protocol.

    When the last connection closes, a timer blanks the strip after
    ``idle_timeout`` seconds unless a new connection arrives first.
    """

    # Ask StreamRequestHandler.setup() to set TCP_NODELAY (no Nagle packet
    # coalescing) so small latency replies go out immediately.
    disable_nagle_algorithm = True

    # Number of currently open connections, shared by all handler threads.
    connection_count = 0

    # Pending threading.Timer that blanks the strip once every client has
    # disconnected, or None.
    timer = None

    # Guards connection_count and timer, which the per-connection handler
    # threads and the timer thread all update.
    _lock = threading.Lock()

    # Seconds to wait after the last disconnect before blanking the strip.
    idle_timeout = 60.0

    def setup(self):
        """Prepare the connection and cancel any pending strip-blanking timer."""
        # The base setup() is what applies disable_nagle_algorithm (and any
        # timeout) and creates rfile/wfile; skipping it left Nagle enabled.
        socketserver.StreamRequestHandler.setup(self)
        with MaxMSPJitHandler._lock:
            MaxMSPJitHandler.connection_count += 1
            if MaxMSPJitHandler.timer is not None:
                MaxMSPJitHandler.timer.cancel()
                MaxMSPJitHandler.timer = None

    def finish(self):
        """Tear down the connection; arm the blanking timer if it was the last."""
        try:
            socketserver.StreamRequestHandler.finish(self)
        finally:
            # Always account for the closed connection, even if the base
            # finish() failed while flushing/closing its file objects.
            with MaxMSPJitHandler._lock:
                MaxMSPJitHandler.connection_count -= 1
                if MaxMSPJitHandler.connection_count < 1:
                    timer = threading.Timer(self.idle_timeout,
                                            MaxMSPJitHandler.timeoutStrip)
                    # Daemon, so a pending timer cannot hold up interpreter exit.
                    timer.daemon = True
                    MaxMSPJitHandler.timer = timer
                    timer.start()

    @staticmethod
    def timeoutStrip():
        """Blank the strip after the idle timeout, unless a client reconnected."""
        with MaxMSPJitHandler._lock:
            MaxMSPJitHandler.timer = None
            # A client may have connected after this timer had already fired
            # (too late for cancel()); don't blank the strip underneath it.
            if MaxMSPJitHandler.connection_count > 0:
                return
        strip.clear()
        strip.show()

    def readStruct(self, bytes, fmt=None):
        """Read exactly ``bytes`` bytes from the client, optionally unpacking them.

        recv() may legitimately return fewer bytes than requested while more
        data is still on the way, so this loops until the full count arrives.

        :param bytes: number of bytes to read (the name shadows the builtin
            but is kept for compatibility).  0 reads nothing.
        :param fmt: optional struct format; when given, the first
            struct.calcsize(fmt) bytes are unpacked and the tuple returned.
        :returns: the unpacked tuple, or the raw data when fmt is falsy; an
            empty string when ``bytes`` is 0.
        :raises EOFError: if the peer closes before ``bytes`` bytes arrive.
        :raises ValueError: if ``bytes`` is negative (a corrupt length field).
        """
        if not bytes:
            return b""
        if bytes < 0:
            raise ValueError('Negative read length {0}'.format(bytes))
        chunks = []
        received = 0
        while received < bytes:
            # Cap each recv so a huge length doesn't allocate a huge buffer.
            chunk = self.request.recv(min(bytes - received, 65536))
            if not chunk:
                raise EOFError('Unexpected EOF at {0} bytes'.format(received))
            chunks.append(chunk)
            received += len(chunk)
        data = b"".join(chunks)
        if fmt:
            return struct.unpack(fmt, data[:struct.calcsize(fmt)])
        return data

    def handle(self):
        """Decode chunks from the client until it disconnects or errors out."""
        # A clean disconnect is only expected between chunks, i.e. while
        # waiting for the next header.  Initialized before the try so the
        # except clauses can always read it.
        eof_expected = True
        try:
            while True:
                eof_expected = True
                (hdr, hdrlen) = self.readStruct(8, '!4sL')
                eof_expected = False

                # A very large length with a zero low half is treated as a
                # sign of a second, nested header: read it and use its length.
                if (hdrlen > 0xffff) and ((hdrlen & 0xffff) == 0):
                    (_, hdrlen) = self.readStruct(8, '!4sL')

                if hdr == b"JMTX":
                    self._handleMatrix(hdrlen)
                else:
                    print("Unknown: {0}".format(hdr))
                    # Read the rest of this chunk's header and throw it away.
                    self.readStruct(hdrlen - 8)

        except socket.error as ex:
            if not eof_expected:
                print("Unexpected socket error, aborting: {0}".format(ex))

        except EOFError as ex:
            if not eof_expected:
                print("Unexpected EOF in stream, aborting: {0}".format(ex))

        finally:
            self.request.close()

    def _handleMatrix(self, hdrlen):
        """Consume the rest of one JMTX chunk, display it, and send a JMLP reply.

        :param hdrlen: length from the chunk header (including its 8 bytes).
        """
        beg_time = time.time()

        # Matrix header layout: planecount, type, dimcount, dim[32],
        # dimstride[32], datasize (all uint32), then the sender's time (double).
        hdrdata = self.readStruct(hdrlen - 8, '!LLL32L32LLd')
        pcount, ptype = hdrdata[0], hdrdata[1]
        width = hdrdata[3]                      # dim[0]
        datasize, sent_time = hdrdata[67], hdrdata[68]

        if ptype == 0 and pcount == 4:
            # Only the cells that fit on the display are decoded.
            fdim_size = min(width, DISPLAY_WIDTH) * pcount
            fdim = self.readStruct(fdim_size, '!{0}B'.format(fdim_size))
            # Drain the remainder of the matrix data.
            self.readStruct(datasize - fdim_size)

            for pixel_num, argb in enumerate(grouper(fdim, 4)):
                # Drop the alpha plane; the strip takes (R, G, B).
                strip.setPixelColor(pixel_num, argb[1], argb[2], argb[3])
            strip.show()
        else:
            # Not a type we can display: discard the data but still reply.
            print("Unhandled JIT Matrix Type")
            self.readStruct(datasize)

        end_time = time.time()
        # '4s' packs the whole 4-byte id ('s' alone packs just 1 byte), and
        # sendall() keeps writing until the full packet has been sent.
        self.request.sendall(struct.pack('!4sddd', b'JMLP', sent_time,
                                         beg_time, end_time))

class MaxMSPJitServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """Threaded TCP server that runs each connection's handler in its own thread.

    The constructor is inherited unchanged from TCPServer:
    ``MaxMSPJitServer(server_address, RequestHandlerClass,
    bind_and_activate=True)``.  (The former override only forwarded the
    first two arguments and hid ``bind_and_activate``.)
    """

    # Handler threads are daemons, so they do not keep the process alive
    # once the main thread exits (e.g. on Ctrl-C).
    daemon_threads = True

    # Set SO_REUSEADDR so a restarted server can rebind the port without
    # waiting for old connections to leave TIME_WAIT.
    allow_reuse_address = True

if __name__ == "__main__":
    server = MaxMSPJitServer((HOST, PORT), MaxMSPJitHandler)
    # Serve until interrupted.  The SIGINT handler installed above (cleanup)
    # blanks the strip and raises SystemExit; KeyboardInterrupt is caught in
    # case that handler is not in place.
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket however we leave the loop.
        server.server_close()
    # Exit only when run as a script; the old module-level sys.exit(0) also
    # fired on import.
    sys.exit(0)

