asteroid-automator/gamemodel.py
John McCardle 5499aa5e49 Started classifying on-screen objects with unified "analyse_frame" function
Rocks, lives, and missiles are identified by their rectangles from cv2.matchTemplate.
Clusters from SIFT are checked against these detected rectangles.
The remaining objects (like the rotated ship) are classified as "mysteries". These will be targets for further analysis (like ship angle determination), or we can just shoot at them :)
2021-12-23 13:57:50 -05:00

226 lines
9 KiB
Python

import gameio
import cv2
import numpy as np
from utility import *
import pointcluster
class GameModel:
    """Platform-independent representation of the game's state.

    The model grabs a screenshot through the supplied game IO object,
    caches it as a grayscale frame (plus a colour copy for drawing) and
    offers several detectors that all work on that cached frame.

    Detected rectangles are returned as ``(top_left, bottom_right, label)``
    tuples, where both corners are ``(x, y)`` pixel positions.
    """

    # BGR colours used by display_results for each rectangle label.
    _LABEL_COLORS = {
        "big": (255, 0, 0),
        "normal": (0, 255, 0),
        "small": (0, 0, 255),
        "missile": (0, 255, 128),
        "ship_on": (0, 0, 128),
        "ship_off": (0, 64, 128),
    }

    def __init__(self, io: "gameio.AbstractGameIO"):
        """Load the sprite templates and set the detection thresholds.

        Args:
            io: Game IO object; only its ``fetch_sshot()`` method is used
                by this class.
        """
        self.gameio = io
        # Templates are read as grayscale (flag 0).
        # NOTE(review): cv2.imread returns None instead of raising when a
        # file cannot be read, which would only surface later as an
        # AttributeError on ``.shape`` -- confirm the asset paths resolve
        # from the working directory.
        self.asteroids = [
            ("big", cv2.imread("images/game_assets/rock-big.png", 0)),
            ("normal", cv2.imread("images/game_assets/rock-normal.png", 0)),
            ("small", cv2.imread("images/game_assets/rock-small.png", 0)),
        ]
        self.ships = [
            ("ship_off", cv2.imread("images/game_assets/spaceship-off.png", 0)),
            ("ship_on", cv2.imread("images/game_assets/spaceship-on.png", 0)),
        ]
        #self.missile = ("missile", cv2.imread("images/game_assets/missile.png", 0))
        self.frame = None        # cached grayscale frame, filled lazily by with_frame
        self.color_frame = None  # colour (BGR) copy of the cached frame, used for drawing
        self.prev_frame = None   # frame that was current before the last clear_frame()
        self.cv_template_thresh = 0.6  # reconfigurable at runtime
        # Compared against squared_distance(), i.e. a squared pixel distance.
        self.duplicate_dist_thresh = 36

    def with_frame(fn):
        """Decorator to process screenshot to cv2 format once upon first requirement, then reuse.

        The wrapped method can rely on ``self.frame`` (grayscale) and
        ``self.color_frame`` (BGR) being populated.
        """
        from functools import wraps

        @wraps(fn)  # keep the decorated method's name and docstring
        def inner(self, *args, **kwargs):
            if self.frame is None:
                sshot = self.gameio.fetch_sshot()
                # presumably an RGB image convertible by np.array -- verify
                # against the gameio implementation.
                open_cv_image = np.array(sshot)
                # Convert RGB to BGR by reversing the channel axis.
                bgr = open_cv_image[:, :, ::-1].copy()
                self.color_frame = np.copy(bgr)
                self.frame = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
            return fn(self, *args, **kwargs)
        return inner

    def clear_frame(self):
        """Drop the cached frame so the next detector call fetches a new one.

        The frame being dropped is kept in ``self.prev_frame``.
        """
        # Fixed: the original read the undefined bare name ``frame``.
        self.prev_frame = self.frame
        self.frame = None

    def _match_templates(self, templates):
        """Template-match every ``(label, image)`` pair against the cached frame.

        Shared implementation of find_asteroids and find_ships; callers
        must already have populated ``self.frame`` (they are decorated
        with with_frame).

        Returns:
            A list of ``(top_left, bottom_right, label)`` tuples.
        """
        rects = []
        for label, template in templates:
            h, w = template.shape
            res = cv2.matchTemplate(self.frame, template, cv2.TM_CCOEFF_NORMED)
            loc = np.where(res >= self.cv_template_thresh)
            # np.where yields (rows, cols); reversing gives (x, y) points.
            for pt in zip(*loc[::-1]):
                # Skip a hit that is too close to the most recently accepted
                # one. Only the last accepted rectangle is compared, and the
                # list is shared across templates, as in the original code.
                if rects and squared_distance(rects[-1][0], pt) <= self.duplicate_dist_thresh:
                    continue
                rects.append((pt, (pt[0] + w, pt[1] + h), label))
        return rects

    @with_frame
    def find_asteroids(self):
        """Return rectangles for asteroid sprites matched in the current frame."""
        return self._match_templates(self.asteroids)

    @with_frame
    def display_results(self, rects=(), pointsets=(), circles=()):
        """Draws results on the current frame for test purposes.

        Args:
            rects: Iterable of ``(top_left, bottom_right, label)``; the
                label must be a key of ``_LABEL_COLORS`` (KeyError otherwise).
            pointsets: Iterable of point sequences, each drawn as a closed
                polyline.
            circles: Iterable of ``(center, radius, label)``.

        Opens an OpenCV window and blocks in ``cv2.waitKey(0)`` until a key
        is pressed.
        """
        displayable = np.copy(self.color_frame)
        for top_left, bottom_right, label in rects:
            color = self._LABEL_COLORS[label]
            cv2.rectangle(displayable, top_left, bottom_right, color, 1)
            cv2.putText(displayable, label, top_left,
                        cv2.FONT_HERSHEY_PLAIN,
                        1.0, color)
        for ps in pointsets:
            color = (0, 255, 255)
            cv2.polylines(displayable, np.int32([ps]), True, color)
        for center, radius, label in circles:
            color = (255, 255, 0)
            cv2.circle(displayable, np.int32(center), int(radius), color, 1)
            cv2.putText(displayable, label, np.int32(center),
                        cv2.FONT_HERSHEY_PLAIN,
                        1.0, color)
        cv2.imshow("Results", displayable)
        cv2.waitKey(0)

    @with_frame
    def frame_sift(self):
        """Cluster the SIFT keypoints of the current frame.

        Returns:
            Whatever ``pointcluster.cluster_set`` returns for the keypoint
            positions; other code in this class reads ``center``,
            ``max_distance`` and ``points`` from each cluster.
        """
        sift = cv2.SIFT_create()
        frame_kp, _ = sift.detectAndCompute(self.frame, None)
        # Descriptor matching of the frame against the ship sprites
        # (cv2.BFMatcher) was prototyped here and is currently disabled;
        # only keypoint positions are used.
        # The clustering distance is derived from the first ship sprite's
        # dimensions, scaled by 0.85 (on the squared radius).
        ship_rsq = rect_radius_squared(*self.ships[0][1].shape) * 0.85
        clusters = pointcluster.cluster_set([k.pt for k in frame_kp], sqrt(ship_rsq))
        return clusters

    @with_frame
    def find_ships(self):
        """Return rectangles for ship sprites matched in the current frame."""
        return self._match_templates(self.ships)

    @with_frame
    def find_missiles(self):
        """Detect small blobs in the frame and return them as "missile" rects.

        Returns:
            A list of ``(top_left, bottom_right, "missile")`` tuples, each a
            9-pixel square centred on a detected blob.
        """
        # Setup SimpleBlobDetector parameters.
        params = cv2.SimpleBlobDetector_Params()
        # Change thresholds
        params.minThreshold = 10
        params.maxThreshold = 200
        # Filter by Area: only an upper bound is set.
        params.filterByArea = True
        params.maxArea = 100
        # Filter by Convexity
        params.filterByConvexity = True
        params.minConvexity = 0.95
        # Filter by Inertia
        params.filterByInertia = True
        params.minInertiaRatio = 0.4
        detector = cv2.SimpleBlobDetector_create(params)
        # The detector is run on the inverted black/white frame.
        keypoints = detector.detect(cv2.bitwise_not(self.frame))

        half = 9 / 2  # half the side length, in pixels, of a missile rect
        return [
            ((int(k.pt[0] - half), int(k.pt[1] - half)),
             (int(k.pt[0] + half), int(k.pt[1] + half)),
             "missile")
            for k in keypoints
        ]

    @staticmethod
    def _cluster_in_labeled_rect(cluster, labeled_objects):
        """Return True if ``cluster`` lies completely inside a labeled rect.

        Two conditions must hold for the same rectangle:

        * ``cluster.max_distance`` is smaller than the rectangle's longer
          side -- this stops big point clusters from being regarded as
          small objects (player ship being matched "inside" a missile);
        * every point of the cluster is inside the rectangle; checking
          every point rather than just the center covers the case where
          the center is off.
        """
        for top_left, bottom_right, _label in labeled_objects:
            longest_side = max(bottom_right[0] - top_left[0],
                               bottom_right[1] - top_left[1])
            if cluster.max_distance >= longest_side:
                continue
            if all(point_in_rect(p, (top_left, bottom_right)) for p in cluster.points):
                return True
        return False

    def analyse_frame(self):
        """Classify on-screen objects and display the result.

        Rocks, lives (ship sprites) and missiles are found by their
        dedicated detectors. SIFT clusters that are not fully contained in
        one of those rectangles are reported as "mystery" circles. The
        result is drawn with display_results; nothing is returned.
        """
        rocks = self.find_asteroids()
        lives = self.find_ships()
        shots = self.find_missiles()
        clusters = self.frame_sift()
        labeled_objects = rocks + lives + shots

        mystery_clusters = [
            c for c in clusters
            if not self._cluster_in_labeled_rect(c, labeled_objects)
        ]
        # A falsy max_distance falls back to a radius of 5 so the circle
        # is still drawn.
        r_circles = [(c.center, c.max_distance or 5, f"mystery_{i}")
                     for i, c in enumerate(mystery_clusters)]
        # Fixed: the original called the module-global ``gm`` here, which
        # only exists when this file is run as a script.
        self.display_results(rects=labeled_objects, circles=r_circles)
if __name__ == '__main__':
    # Manual smoke test: run every detector once on the current screen
    # and show the annotated frame.
    import platform

    # TODO: Detect OSX or show a message of sadness
    # (anything that is not Windows currently gets the Linux backend).
    io_class = gameio.WindowsGameIO if platform.system() == "Windows" else gameio.LinuxGameIO
    io = io_class()
    gm = GameModel(io)

    # For testing purposes the window location is pinned to the top-left
    # of the author's screen. io.loc is None when the title screen isn't
    # found; setting it manually crops all screenshots as if the title
    # had been found.
    import pyscreeze
    io.loc = pyscreeze.Box(0, 25, 800, 599)
    from pprint import pprint

    rock_rects = gm.find_asteroids()
    print(f"Found {len(rock_rects)} asteroids")

    sift_clusters = gm.frame_sift()
    ship_rects = gm.find_ships()
    missile_rects = gm.find_missiles()

    # Every cluster is outlined by its points and circled with a fixed
    # radius derived from the first ship sprite's dimensions.
    outlines = [cluster.points for cluster in sift_clusters]
    ship_radius = sqrt(rect_radius_squared(*gm.ships[0][1].shape))
    cluster_circles = [
        (cluster.center, ship_radius, f"cluster_{idx}")
        for idx, cluster in enumerate(sift_clusters)
    ]

    gm.display_results(
        rects=rock_rects + ship_rects + missile_rects,
        pointsets=outlines,
        circles=cluster_circles,
    )
    gm.analyse_frame()